diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index b934c7f831a2..36c9efb3d8c3 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -113,20 +113,54 @@ com.github.os72 protoc-jar-maven-plugin 3.11.4 - + generate-test-sources + v2-protos run com.google.protobuf:protoc:${protobuf.version} ${protobuf.version} - - src/test/resources/protobuf - - test + src/test/protobuf/v2 + + + java + test + + + descriptor + test + target/scala-${scala.binary.version}/test-classes + descriptor-set-v2 + + + + + + generate-test-sources + v3-protos + + run + + + com.google.protobuf:protoc:${protobuf.version} + ${protobuf.version} + src/test/protobuf/v3 + + + java + test + + + descriptor + test + target/scala-${scala.binary.version}/test-classes + descriptor-set-v3 + + diff --git a/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto b/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto new file mode 100644 index 000000000000..f2e11486a448 --- /dev/null +++ b/connector/protobuf/src/test/protobuf/v2/catalyst_types.proto @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package org.apache.spark.sql.protobuf.protos.v2; +option java_outer_classname = "CatalystTypes"; + +// TODO: import one or more protobuf files. + +message BooleanMsg { + optional bool bool_type = 1; +} +message IntegerMsg { + optional int32 int32_type = 1; +} +message DoubleMsg { + optional double double_type = 1; +} +message FloatMsg { + optional float float_type = 1; +} +message BytesMsg { + optional bytes bytes_type = 1; +} +message StringMsg { + optional string string_type = 1; +} + +message Person { + optional string name = 1; + optional int32 age = 2; +} + +message Bad { + optional bytes col_0 = 1; + optional double col_1 = 2; + optional string col_2 = 3; + optional float col_3 = 4; + optional int64 col_4 = 5; +} + +message Actual { + optional string col_0 = 1; + optional int32 col_1 = 2; + optional float col_2 = 3; + optional bool col_3 = 4; + optional double col_4 = 5; +} + +message oldConsumer { + optional string key = 1; +} + +message newProducer { + optional string key = 1; + optional int32 value = 2; +} + +message newConsumer { + optional string key = 1; + optional int32 value = 2; + optional Actual actual = 3; +} + +message oldProducer { + optional string key = 1; +} diff --git a/connector/protobuf/src/test/protobuf/v2/functions_suite.proto b/connector/protobuf/src/test/protobuf/v2/functions_suite.proto new file mode 100644 index 000000000000..0002ebe55f35 --- /dev/null +++ b/connector/protobuf/src/test/protobuf/v2/functions_suite.proto @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package org.apache.spark.sql.protobuf.protos.v2; + +option java_outer_classname = "SimpleMessageProtos"; + +message SimpleMessageJavaTypes { + optional int64 id = 1; + optional string string_value = 2; + optional int32 int32_value = 3; + optional int64 int64_value = 4; + optional double double_value = 5; + optional float float_value = 6; + optional bool bool_value = 7; + optional bytes bytes_value = 8; +} + +message SimpleMessage { + optional int64 id = 1; + optional string string_value = 2; + optional int32 int32_value = 3; + optional uint32 uint32_value = 4; + optional sint32 sint32_value = 5; + optional fixed32 fixed32_value = 6; + optional sfixed32 sfixed32_value = 7; + optional int64 int64_value = 8; + optional uint64 uint64_value = 9; + optional sint64 sint64_value = 10; + optional fixed64 fixed64_value = 11; + optional sfixed64 sfixed64_value = 12; + optional double double_value = 13; + optional float float_value = 14; + optional bool bool_value = 15; + optional bytes bytes_value = 16; +} + +message SimpleMessageRepeated { + optional string key = 1; + optional string value = 2; + enum NestedEnum { + ESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + repeated string rstring_value = 3; + repeated int32 rint32_value = 4; + repeated bool rbool_value = 5; + repeated int64 rint64_value = 6; + repeated float rfloat_value = 7; + repeated double rdouble_value = 8; + repeated bytes rbytes_value = 9; + repeated NestedEnum rnested_enum = 10; +} + +message BasicMessage { + optional int64 id = 1; + optional string string_value = 2; + optional int32 int32_value = 3; + optional int64 int64_value = 4; + optional double double_value = 5; + optional float float_value = 6; + optional bool bool_value = 7; + optional bytes bytes_value = 8; +} + +message RepeatedMessage { + repeated BasicMessage basic_message = 1; +} + +message SimpleMessageMap { + optional string key = 1; + optional string value = 2; + map string_mapdata = 3; + map int32_mapdata = 4; + map uint32_mapdata = 5; + map sint32_mapdata = 6; + map float32_mapdata = 7; + map sfixed32_mapdata = 8; + map int64_mapdata = 9; + map uint64_mapdata = 10; + map sint64_mapdata = 11; + map fixed64_mapdata = 12; + map sfixed64_mapdata = 13; + map double_mapdata = 14; + map float_mapdata = 15; + map bool_mapdata = 16; + map bytes_mapdata = 17; +} + +message BasicEnumMessage { + enum BasicEnum { + NOTHING = 0; + FIRST = 1; + SECOND = 2; + } +} + +message SimpleMessageEnum { + optional string key = 1; + optional string value = 2; + enum NestedEnum { + NESTED_NOTHING = 0; + NESTED_FIRST = 1; + NESTED_SECOND = 2; + } + optional BasicEnumMessage.BasicEnum basic_enum = 3; + optional NestedEnum nested_enum = 4; +} + + +message OtherExample { + optional string other = 1; +} + +message IncludedExample { + optional string included = 1; + 
optional OtherExample other = 2; +} + +message MultipleExample { + optional IncludedExample included_example = 1; +} + +message recursiveA { + optional string keyA = 1; + optional recursiveB messageB = 2; +} + +message recursiveB { + optional string keyB = 1; + optional recursiveA messageA = 2; +} + +message recursiveC { + optional string keyC = 1; + optional recursiveD messageD = 2; +} + +message recursiveD { + optional string keyD = 1; + repeated recursiveC messageC = 2; +} + +message requiredMsg { + optional string key = 1; + optional int32 col_1 = 2; + optional string col_2 = 3; + optional int32 col_3 = 4; +} + +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto +message Timestamp { + optional int64 seconds = 1; + optional int32 nanos = 2; +} + +message timeStampMsg { + optional string key = 1; + optional Timestamp stmp = 2; +} +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/duration.proto +message Duration { + optional int64 seconds = 1; + optional int32 nanos = 2; +} + +message durationMsg { + optional string key = 1; + optional Duration duration = 2; +} + +message ProtoWithDefaults { + optional string user_name = 1; + required int32 id = 2; + optional int32 api_quota = 3 [default = 100]; // Default 100 qps. + optional string location = 4 [default = "Unknown"]; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto b/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto new file mode 100644 index 000000000000..336c86cf7e79 --- /dev/null +++ b/connector/protobuf/src/test/protobuf/v2/pyspark_test.proto @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package org.apache.spark.sql.protobuf; +option java_outer_classname = "SimpleMessageProtos"; + +message SimpleMessage { + optional int32 age = 1; + optional string name = 2; + optional int64 score = 3; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/protobuf/v2/serde_suite.proto b/connector/protobuf/src/test/protobuf/v2/serde_suite.proto new file mode 100644 index 000000000000..c401766c9f0d --- /dev/null +++ b/connector/protobuf/src/test/protobuf/v2/serde_suite.proto @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto2"; + +package org.apache.spark.sql.protobuf.protos.v2; +option java_outer_classname = "SerdeSuiteProtos"; + +/* Clean Message*/ +message SerdeBasicMessage { + optional Foo foo = 1; +} + +message Foo { + optional int32 bar = 1; +} + +/* Field Type missMatch in root Message*/ +message MissMatchTypeInRoot { + optional int64 foo = 1; +} + +/* Field bar missing from protobuf and Available in SQL*/ +message FieldMissingInProto { + optional MissingField foo = 1; +} + +message MissingField { + optional int64 barFoo = 1; +} + +/* Deep-nested field bar type missMatch Message*/ +message MissMatchTypeInDeepNested { + optional TypeMissNested top = 1; +} + +message TypeMissNested { + optional TypeMiss foo = 1; +} + +message TypeMiss { + optional int64 bar = 1; +} + +/* Field boo missing from SQL root, but available in Protobuf root*/ +message FieldMissingInSQLRoot { + optional Foo foo = 1; + optional int32 boo = 2; +} + +/* Field baz missing from SQL nested and available in Protobuf nested*/ +message FieldMissingInSQLNested { + optional Baz foo = 1; +} + +message Baz { + optional int32 bar = 1; + optional int32 baz = 2; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/protobuf/v3/catalyst_types.proto similarity index 79% rename from connector/protobuf/src/test/resources/protobuf/catalyst_types.proto rename to connector/protobuf/src/test/protobuf/v3/catalyst_types.proto index 1deb193438c2..f8db4f382a5a 100644 --- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto +++ b/connector/protobuf/src/test/protobuf/v3/catalyst_types.proto @@ -14,12 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto -// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto syntax = "proto3"; -package org.apache.spark.sql.protobuf.protos; +package org.apache.spark.sql.protobuf.protos.v3; option java_outer_classname = "CatalystTypes"; // TODO: import one or more protobuf files. diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/protobuf/v3/functions_suite.proto similarity index 90% rename from connector/protobuf/src/test/resources/protobuf/functions_suite.proto rename to connector/protobuf/src/test/protobuf/v3/functions_suite.proto index 60f8c2621415..f394500e1465 100644 --- a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto +++ b/connector/protobuf/src/test/protobuf/v3/functions_suite.proto @@ -14,13 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -// To compile and create test class: -// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto -// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/functions_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto syntax = "proto3"; -package org.apache.spark.sql.protobuf.protos; +package org.apache.spark.sql.protobuf.protos.v3; option java_outer_classname = "SimpleMessageProtos"; diff --git a/connector/protobuf/src/test/resources/protobuf/pyspark_test.proto b/connector/protobuf/src/test/protobuf/v3/pyspark_test.proto similarity index 66% rename from connector/protobuf/src/test/resources/protobuf/pyspark_test.proto rename to connector/protobuf/src/test/protobuf/v3/pyspark_test.proto index 8750371349a0..f48fde9010b7 100644 --- a/connector/protobuf/src/test/resources/protobuf/pyspark_test.proto +++ b/connector/protobuf/src/test/protobuf/v3/pyspark_test.proto @@ -14,17 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -// TODO(SPARK-40777): Instead of saving .desc files in resources, generate during build. -// To compile and create test class: -// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/pyspark_test.proto -// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/pyspark_test.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/pyspark_test.proto syntax = "proto3"; package org.apache.spark.sql.protobuf; option java_outer_classname = "SimpleMessageProtos"; - message SimpleMessage { int32 age = 1; string name = 2; diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/protobuf/v3/serde_suite.proto similarity index 78% rename from connector/protobuf/src/test/resources/protobuf/serde_suite.proto rename to connector/protobuf/src/test/protobuf/v3/serde_suite.proto index a7459213a87b..be9eb6c815f7 100644 --- a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto +++ b/connector/protobuf/src/test/protobuf/v3/serde_suite.proto @@ -14,13 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -// To compile and create test class: -// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto -// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/serde_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto syntax = "proto3"; -package org.apache.spark.sql.protobuf.protos; +package org.apache.spark.sql.protobuf.protos.v3; option java_outer_classname = "SerdeSuiteProtos"; /* Clean Message*/ diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc deleted file mode 100644 index 59255b488a03..000000000000 --- a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc +++ /dev/null @@ -1,48 +0,0 @@ - -‰ -Cconnector/protobuf/src/test/resources/protobuf/catalyst_types.protoorg.apache.spark.sql.protobuf") - -BooleanMsg - bool_type (RboolType"+ - -IntegerMsg - -int32_type (R int32Type", - DoubleMsg - double_type (R -doubleType") -FloatMsg - -float_type (R floatType") -BytesMsg - -bytes_type ( R bytesType", - StringMsg - string_type ( R -stringType". -Person -name ( Rname -age (Rage"n -Bad -col_0 ( Rcol0 -col_1 (Rcol1 -col_2 ( Rcol2 -col_3 (Rcol3 -col_4 (Rcol4"q -Actual -col_0 ( Rcol0 -col_1 (Rcol1 -col_2 (Rcol2 -col_3 (Rcol3 -col_4 (Rcol4" - oldConsumer -key ( Rkey"5 - newProducer -key ( Rkey -value (Rvalue"t - newConsumer -key ( Rkey -value (Rvalue= -actual ( 2%.org.apache.spark.sql.protobuf.ActualRactual" - oldProducer -key ( RkeyBB CatalystTypesbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc deleted file mode 100644 index 6e3a39672772..000000000000 Binary files a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc and /dev/null differ diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc deleted file mode 100644 index 3d1847eecc5c..000000000000 --- a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc +++ /dev/null @@ -1,27 +0,0 @@ - -² -Fconnector/protobuf/src/test/resources/protobuf/proto_serde_suite.protoorg.apache.spark.sql.protobuf"D - BasicMessage4 -foo ( 2".org.apache.spark.sql.protobuf.FooRfoo" -Foo -bar (Rbar"' -MissMatchTypeInRoot -foo (Rfoo"T -FieldMissingInProto= -foo ( 2+.org.apache.spark.sql.protobuf.MissingFieldRfoo"& - MissingField -barFoo (RbarFoo"\ -MissMatchTypeInDeepNested? 
-top ( 2-.org.apache.spark.sql.protobuf.TypeMissNestedRtop"K -TypeMissNested9 -foo ( 2'.org.apache.spark.sql.protobuf.TypeMissRfoo" -TypeMiss -bar (Rbar"_ -FieldMissingInSQLRoot4 -foo ( 2".org.apache.spark.sql.protobuf.FooRfoo -boo (Rboo"O -FieldMissingInSQLNested4 -foo ( 2".org.apache.spark.sql.protobuf.BazRfoo") -Baz -bar (Rbar -baz (RbazBBSimpleMessageProtosbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala index 19774a2ad07e..9484bb8ecce6 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} -import org.apache.spark.sql.protobuf.protos.CatalystTypes.BytesMsg +import org.apache.spark.sql.protobuf.protos.v3.CatalystTypes.BytesMsg import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} import org.apache.spark.sql.sources.{EqualTo, Not} import org.apache.spark.sql.test.SharedSparkSession @@ -33,34 +33,41 @@ import org.apache.spark.unsafe.types.UTF8String class ProtobufCatalystDataConversionSuite extends SparkFunSuite + with ProtobufTestUtils with SharedSparkSession with ExpressionEvalHelper { - private val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") - private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.CatalystTypes$" + private val descPathV2 = descFilePathV2("catalyst_types.desc") + private val descPathV3 = descFilePathV3("catalyst_types.desc") + private val outerClass = "CatalystTypes" private def checkResultWithEval( data: Literal, - descFilePath: String, messageName: String, expected: Any): Unit = { - withClue("(Eval check with Java class name)") { - val className = s"$javaClassNamePrefix$messageName" - checkEvaluation( - ProtobufDataToCatalyst( - CatalystDataToProtobuf(data, className), - className, - descFilePath = None), - prepareExpectedResult(expected)) - } - withClue("(Eval check with descriptor file)") { - checkEvaluation( - ProtobufDataToCatalyst( - CatalystDataToProtobuf(data, messageName, Some(descFilePath)), - messageName, - descFilePath = Some(descFilePath)), - prepareExpectedResult(expected)) + // Check all the 16 variations. + // 4 variations for how the record is written and 4 variations for how it is read. + + val variations = Seq( + (messageName, Some(descPathV2), "V2 descriptor file"), + (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"), + (messageName, Some(descPathV3), "V3 descriptor file"), + (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class") + ) + + for ( // 16 permutations of 4 variations above. 
+ (writerMessageName, writerDesc, writerClue) <- variations; + (readerMessageName, readerDesc, readerClue) <- variations + ) { + withClue(s"(Eval check with $writerClue for writing $readerClue for reading") { + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, writerMessageName, writerDesc), + readerMessageName, + descFilePath = readerDesc), + prepareExpectedResult(expected)) + } } } @@ -139,7 +146,6 @@ class ProtobufCatalystDataConversionSuite checkResultWithEval( input, - testFileDesc, messageName, input.eval()) } @@ -162,7 +168,7 @@ class ProtobufCatalystDataConversionSuite // Verify Java class deserializer matches with descriptor based serializer. val javaDescriptor = ProtobufUtils - .buildDescriptorFromJavaClass(s"$javaClassNamePrefix$messageName") + .buildDescriptorFromJavaClass(javaClassFullNameV3(outerClass, messageName)) assert(dataType == SchemaConverters.toSqlType(javaDescriptor).dataType) val javaDeserialized = new ProtobufDeserializer(javaDescriptor, dataType, filters) .deserialize(DynamicMessage.parseFrom(javaDescriptor, data.toByteArray)) @@ -189,7 +195,7 @@ class ProtobufCatalystDataConversionSuite val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema) val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema) val input = Literal.create(converter(data), actualSchema) - checkUnsupportedRead(input, testFileDesc, "Actual", "Bad") + checkUnsupportedRead(input, descPathV3, "Actual", "Bad") } } @@ -199,7 +205,7 @@ class ProtobufCatalystDataConversionSuite .add("name", "string") .add("age", "int") - val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person") + val descriptor = ProtobufUtils.buildDescriptor(descPathV3, "Person") val dynamicMessage = DynamicMessage .newBuilder(descriptor) .setField(descriptor.findFieldByName("name"), "Maxim") @@ -207,16 +213,16 @@ class ProtobufCatalystDataConversionSuite .build() val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39)) - checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow) + checkDeserialization(descPathV3, "Person", dynamicMessage, expectedRow) checkDeserialization( - testFileDesc, + descPathV3, "Person", dynamicMessage, expectedRow, new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema)) checkDeserialization( - testFileDesc, + descPathV3, "Person", dynamicMessage, None, @@ -233,15 +239,15 @@ class ProtobufCatalystDataConversionSuite .build() val expected = InternalRow(Array[Byte](97, 48, 53)) - checkDeserialization(testFileDesc, "BytesMsg", bytesProto, Some(expected)) + checkDeserialization(descPathV3, "BytesMsg", bytesProto, Some(expected)) } test("Full names for message using descriptor file") { - val withShortName = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") + val withShortName = ProtobufUtils.buildDescriptor(descPathV3, "BytesMsg") assert(withShortName.findFieldByName("bytes_type") != null) val withFullName = ProtobufUtils.buildDescriptor( - testFileDesc, "org.apache.spark.sql.protobuf.BytesMsg") + descPathV3, "org.apache.spark.sql.protobuf.protos.v3.BytesMsg") assert(withFullName.findFieldByName("bytes_type") != null) } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala index 72280fb0d9e2..ef7cad250534 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala +++ 
b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -25,31 +25,43 @@ import com.google.protobuf.{ByteString, DynamicMessage} import org.apache.spark.sql.{Column, QueryTest, Row} import org.apache.spark.sql.functions.{lit, struct} -import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated -import org.apache.spark.sql.protobuf.protos.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum +import org.apache.spark.sql.protobuf.protos.v3.SimpleMessageProtos.SimpleMessageRepeated +import org.apache.spark.sql.protobuf.protos.v3.SimpleMessageProtos.SimpleMessageRepeated.NestedEnum import org.apache.spark.sql.protobuf.utils.ProtobufUtils import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} -class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Serializable { +class ProtobufFunctionsSuite + extends QueryTest + with ProtobufTestUtils + with SharedSparkSession + with Serializable { import testImplicits._ - val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/") - private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SimpleMessageProtos$" + private val descPathV2 = descFilePathV2("functions_suite.desc") + private val descPathV3 = descFilePathV3("functions_suite.desc") + private val outerClass = "SimpleMessageProtos" /** - * Runs the given closure twice. Once with descriptor file and second time with Java class name. + * Runs the given closure 4 times with combination of v2, v3 protos and descriptor files. */ private def checkWithFileAndClassName(messageName: String)( fn: (String, Option[String]) => Unit): Unit = { - withClue("(With descriptor file)") { - fn(messageName, Some(testFileDesc)) - } - withClue("(With Java class name)") { - fn(s"$javaClassNamePrefix$messageName", None) + + val variations = Seq( + (messageName, Some(descPathV2), "V2 descriptor file"), + (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"), + (messageName, Some(descPathV3), "V3 descriptor file"), + (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class") + ) + + for ((messageName, descOpt, clue) <- variations) { + withClue(s"(With $clue)") { + fn(messageName, descOpt) } + } } // A wrapper to invoke the right variable of from_protobuf() depending on arguments. 
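// Editor's note (illustrative, not part of this patch): the wrapper bodies are not shown in
// this diff. A minimal sketch of what they presumably look like, dispatching between the
// descriptor-file and Java-class-name overloads of org.apache.spark.sql.protobuf.functions;
// assumes `functions` and org.apache.spark.sql.Column are in scope.
private def from_protobuf_wrapper(
    col: Column, messageName: String, descFilePathOpt: Option[String]): Column = {
  descFilePathOpt match {
    case Some(descFilePath) => functions.from_protobuf(col, messageName, descFilePath)
    case None => functions.from_protobuf(col, messageName)
  }
}

private def to_protobuf_wrapper(
    col: Column, messageName: String, descFilePathOpt: Option[String]): Column = {
  descFilePathOpt match {
    case Some(descFilePath) => functions.to_protobuf(col, messageName, descFilePath)
    case None => functions.to_protobuf(col, messageName)
  }
}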
@@ -124,17 +136,17 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri case (name, descFilePathOpt) => val fromProtoDF = df.select( from_protobuf_wrapper($"value", name, descFilePathOpt).as("value_from")) - val toProtoDF = fromProtoDF.select( - to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) - val toFromProtoDF = toProtoDF.select( - from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) - checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + val toProtoDF = fromProtoDF.select( + to_protobuf_wrapper($"value_from", name, descFilePathOpt).as("value_to")) + val toFromProtoDF = toProtoDF.select( + from_protobuf_wrapper($"value_to", name, descFilePathOpt).as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) } } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") { - val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") - val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "BasicMessage") val basicMessage = DynamicMessage .newBuilder(basicMessageDesc) @@ -170,8 +182,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") { - val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") - val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(descPathV2, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV2, "BasicMessage") val basicMessage1 = DynamicMessage .newBuilder(basicMessageDesc) @@ -221,7 +233,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("roundtrip in from_protobuf and to_protobuf - Map") { - val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageMap") + val messageMapDesc = ProtobufUtils.buildDescriptor(descPathV3, "SimpleMessageMap") val mapStr1 = DynamicMessage .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry")) @@ -315,8 +327,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("roundtrip in from_protobuf and to_protobuf - Enum") { - val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum") - val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicEnumMessage") + val messageEnumDesc = ProtobufUtils.buildDescriptor(descPathV2, "SimpleMessageEnum") + val basicEnumDesc = ProtobufUtils.buildDescriptor(descPathV2, "BasicEnumMessage") val dynamicMessage = DynamicMessage .newBuilder(messageEnumDesc) @@ -324,7 +336,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri .setField(messageEnumDesc.findFieldByName("value"), "value") .setField( messageEnumDesc.findFieldByName("nested_enum"), - messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_NOTHING")) .setField( messageEnumDesc.findFieldByName("nested_enum"), messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) @@ -351,9 +363,9 @@ class 
ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { - val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample") - val messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample") - val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample") + val messageMultiDesc = ProtobufUtils.buildDescriptor(descPathV2, "MultipleExample") + val messageIncludeDesc = ProtobufUtils.buildDescriptor(descPathV2, "IncludedExample") + val messageOtherDesc = ProtobufUtils.buildDescriptor(descPathV3, "OtherExample") val otherMessage = DynamicMessage .newBuilder(messageOtherDesc) @@ -386,8 +398,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("Handle recursive fields in Protobuf schema, A->B->A") { - val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA") - val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB") + val schemaA = ProtobufUtils.buildDescriptor(descPathV3, "recursiveA") + val schemaB = ProtobufUtils.buildDescriptor(descPathV3, "recursiveB") val messageBForA = DynamicMessage .newBuilder(schemaB) @@ -422,8 +434,8 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("Handle recursive fields in Protobuf schema, C->D->Array(C)") { - val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC") - val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD") + val schemaC = ProtobufUtils.buildDescriptor(descPathV2, "recursiveC") + val schemaD = ProtobufUtils.buildDescriptor(descPathV2, "recursiveD") val messageDForC = DynamicMessage .newBuilder(schemaD) @@ -458,7 +470,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("Handle extra fields : oldProducer -> newConsumer") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val testFileDesc = descFilePathV2("catalyst_types.desc") val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer") val newConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "newConsumer") @@ -498,7 +510,7 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri } test("Handle extra fields : newProducer -> oldConsumer") { - val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val testFileDesc = descFilePathV3("catalyst_types.desc") val newProducer = ProtobufUtils.buildDescriptor(testFileDesc, "newProducer") val oldConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "oldConsumer") @@ -541,35 +553,38 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri schema ) - val toProtobuf = inputDf.select( - functions.to_protobuf($"requiredMsg", "requiredMsg", testFileDesc) - .as("to_proto")) - - val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] - - val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, "requiredMsg") - val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary) - - assert(actualMessage.getField(messageDescriptor.findFieldByName("key")) - == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) - assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2")) - == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) - assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0) - 
assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) - - val fromProtoDf = toProtobuf.select( - functions.from_protobuf($"to_proto", "requiredMsg", testFileDesc) as 'from_proto) - - assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) - == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0) - == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) - assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null) - assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null) + checkWithFileAndClassName("requiredMsg") { + case (name, descFilePathOpt) => + val toProtobuf = inputDf.select( + to_protobuf_wrapper($"requiredMsg", name, descFilePathOpt) + .as("to_proto")) + + val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] + + val messageDescriptor = ProtobufUtils.buildDescriptor(name, descFilePathOpt) + val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary) + + assert(actualMessage.getField(messageDescriptor.findFieldByName("key")) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2")) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) + + val fromProtoDf = toProtobuf.select( + from_protobuf_wrapper($"to_proto", name, descFilePathOpt) as 'from_proto) + + assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null) + assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null) + } } test("from_protobuf filter to_protobuf") { - val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(descPathV3, "BasicMessage") val basicMessage = DynamicMessage .newBuilder(basicMessageDesc) @@ -680,4 +695,64 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Seri === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) } } + + test("Default and required values in version 2 protobufs") { + // Verifies that V2 protobufs with default values have correct values returned. + + // Proto: + // message ProtoWithDefaults { + // optional string user_name = 1; + // required int32 id = 2; + // optional int32 api_quota = 3 [default = 100]; // Default 100 qps. + // optional string location = 4 [default = "Unknown"]; + // } + + Seq( // Only V2 variations. This proto does not exist in V3 + ("ProtoWithDefaults", Some(descPathV2), "V2 descriptor file"), + (javaClassFullNameV2(outerClass, "ProtoWithDefaults"), None, "V2 Java class"), + ).foreach { + case (name, descFilePathOpt, clue) => + withClue(clue) { + val schema = Seq[Array[Byte]]() + .toDF("bytes") + .select(from_protobuf_wrapper($"bytes", name, descFilePathOpt).as("proto")) + .schema + + val inputRows = Seq( + Row(Row("user-1", 1, 10, "CA")), // All the fields are set. + Row(Row("user-2", 2, null, null)), // Defaults. + Row(Row(null, 3, 100, "NY")) // Explicit default for "api_quota". 
+ ) + + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(inputRows), + schema + ) + + assert( + !schema.head.dataType.asInstanceOf[StructType].fields(1).nullable, + "id should not be nullable" + ) + + val rows = inputDf // Do a round trip with to_protobuf & from_protobuf + .select( + to_protobuf_wrapper($"proto", name, descFilePathOpt).as("proto_bytes") + ) + .select( + from_protobuf_wrapper($"proto_bytes", name, descFilePathOpt).as("proto") + ) + + val results = rows.collect().toSeq.sortBy(_.getStruct(0).getInt(1)) // sort by id. + + // First row should be same as input. + assert(results(0) == inputRows(0)) + + // Second row should have defaults for 'api_quota' and 'location'. + assert(results(1) == Row(Row("user-2", 2, 100, "Unknown"))) + + // Third row should return null for optional field user_name. So same as input. + assert(results(2) == inputRows(2)) + } + } + } } diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala index efc02524e68d..9fc1337476cc 100644 --- a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -30,126 +30,173 @@ import org.apache.spark.sql.types.{IntegerType, StructType} * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on * those classes. */ -class ProtobufSerdeSuite extends SharedSparkSession { +class ProtobufSerdeSuite extends SharedSparkSession with ProtobufTestUtils { import ProtoSerdeSuite._ import ProtoSerdeSuite.MatchType._ - val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/") - private val javaClassNamePrefix = "org.apache.spark.sql.protobuf.protos.SerdeSuiteProtos$" + private val descPathV2 = descFilePathV2("serde_suite.desc") + private val descPathV3 = descFilePathV3("serde_suite.desc") + private val outerClass = "SerdeSuiteProtos" + + /** + * Runs the given closure 4 times with combination of v2, v3 protos and descriptor files. 
+ */ + private def checkWithFileAndClassName(messageName: String)( + fn: (String, Option[String]) => Unit): Unit = { + + val variations = Seq( + (messageName, Some(descPathV2), "V2 descriptor file"), + (javaClassFullNameV2(outerClass, messageName), None, "V2 Java class"), + (messageName, Some(descPathV3), "V3 descriptor file"), + (javaClassFullNameV3(outerClass, messageName), None, "V3 Java class") + ) + + for ((messageName, descOpt, clue) <- variations) { + withClue(s"(With $clue)") { + fn(messageName, descOpt) + } + } + } + test("Test basic conversion") { withFieldMatchType { fieldMatch => - val (top, nest) = fieldMatch match { - case BY_NAME => ("foo", "bar") - } - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + checkWithFileAndClassName("SerdeBasicMessage") { + case (name, descPathOpt) => + val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt) - val dynamicMessageFoo = DynamicMessage - .newBuilder(protoFile.getFile.findMessageTypeByName("Foo")) - .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902) - .build() + val dynamicMessageFoo = DynamicMessage + .newBuilder(protoFile.getFile.findMessageTypeByName("Foo")) + .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902) + .build() - val dynamicMessage = DynamicMessage - .newBuilder(protoFile) - .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo) - .build() + val dynamicMessage = DynamicMessage + .newBuilder(protoFile) + .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo) + .build() - val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch) - val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch) - assert( - serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage) + assert( + serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage) + } } } test("Fail to convert with field type mismatch") { - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot") - withFieldMatchType { fieldMatch => - assertFailedConversionMessage( - protoFile, - Deserializer, - fieldMatch, - "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is incompatible " + - s"(protoType = org.apache.spark.sql.protobuf.MissMatchTypeInRoot.foo " + - s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})".stripMargin) - - assertFailedConversionMessage( - protoFile, - Serializer, - fieldMatch, - s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is incompatible " + - s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""") + checkWithFileAndClassName("MissMatchTypeInRoot") { + case (name, descPathOpt) => + val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt) + + val version = descPathOpt match { // A bit hacky, but ok for now. 
+ case Some(path) => if (path == descPathV2) 2 else 3 + case None => if (name.contains(".v2.")) 2 else 3 + } + + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is " + + s"incompatible (protoType = " + + s"org.apache.spark.sql.protobuf.protos.v$version.MissMatchTypeInRoot.foo " + + s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})") + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is " + + s"incompatible (sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)") + } } } test("Fail to convert with missing nested Protobuf fields for serializer") { - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInProto") + checkWithFileAndClassName("FieldMissingInProto") { + case (name, descPathOpt) => + val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt) - val nonnullCatalyst = new StructType() - .add("foo", new StructType().add("bar", IntegerType, nullable = false)) + val nonnullCatalyst = new StructType() + .add("foo", new StructType().add("bar", IntegerType, nullable = false)) - // serialize fails whether or not 'bar' is nullable - val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema" - assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg) - assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) + // serialize fails whether or not 'bar' is nullable + val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema" + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) + } } test("Fail to convert with deeply nested field type mismatch") { - val protoFile = ProtobufUtils.buildDescriptorFromJavaClass( - s"${javaClassNamePrefix}MissMatchTypeInDeepNested" - ) + val catalyst = new StructType().add("top", CATALYST_STRUCT) withFieldMatchType { fieldMatch => - assertFailedConversionMessage( - protoFile, - Deserializer, - fieldMatch, - s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + - s"is incompatible (protoType = org.apache.spark.sql.protobuf.protos.TypeMiss.bar " + - s"LABEL_OPTIONAL LONG INT64, sqlType = INT)", - catalyst) - - assertFailedConversionMessage( - protoFile, - Serializer, - fieldMatch, - "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because schema " + - """is incompatible (sqlType = INT, protoType = LONG)""", - catalyst) + checkWithFileAndClassName("MissMatchTypeInDeepNested") { + case (name, descPathOpt) => + val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt) + + val version = descPathOpt match { + case Some(path) => if (path == descPathV2) 2 else 3 + case None => if (name.contains(".v2.")) 2 else 3 + } + + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' " + + s"because schema is incompatible (protoType = " + + s"org.apache.spark.sql.protobuf.protos.v$version.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)", + catalyst) + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because " + + "schema is incompatible (sqlType = INT, protoType = LONG)", + catalyst) + } 
} } test("Fail to convert with missing Catalyst fields") { - val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot") - - // serializing with extra fails if extra field is missing in SQL Schema - assertFailedConversionMessage( - protoFile, - Serializer, - BY_NAME, - "Found field 'boo' in Protobuf schema but there is no match in the SQL schema") - - /* deserializing should work regardless of whether the extra field is missing - in SQL Schema or not */ - withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) - withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) - - val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested") - - // serializing with extra fails if extra field is missing in SQL Schema - assertFailedConversionMessage( - protoNestedFile, - Serializer, - BY_NAME, - "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema") - - /* deserializing should work regardless of whether the extra field is missing + + checkWithFileAndClassName("FieldMissingInSQLRoot") { + case (name, descPathOpt) => + val protoFile = ProtobufUtils.buildDescriptor(name, descPathOpt) + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoFile, + Serializer, + BY_NAME, + "Found field 'boo' in Protobuf schema but there is no match in the SQL schema") + + // deserializing should work regardless of whether the extra field is missing + // in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + + val nestedMessageName = name.replace("FieldMissingInSQLRoot", "FieldMissingInSQLNested") + val protoNestedFile = ProtobufUtils.buildDescriptor(nestedMessageName, descPathOpt) + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoNestedFile, + Serializer, + BY_NAME, + "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema") + + /* deserializing should work regardless of whether the extra field is missing in SQL Schema or not */ - withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) - withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + } } /** diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala new file mode 100644 index 000000000000..afc967a63e1e --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufTestUtils.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.protobuf
+
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * A trait used by all the Protobuf function suites.
+ */
+private[sql] trait ProtobufTestUtils extends SQLTestUtils {
+
+  val javaPackageV2 = "org.apache.spark.sql.protobuf.protos.v2"
+  val javaPackageV3 = "org.apache.spark.sql.protobuf.protos.v3"
+
+  def javaClassFullNameV2(outerClass: String, innerClass: String): String =
+    s"$javaPackageV2.$outerClass$$$innerClass"
+
+  def javaClassFullNameV3(outerClass: String, innerClass: String): String =
+    s"$javaPackageV3.$outerClass$$$innerClass"
+
+  def descFilePathV2(descFileName: String): String =
+    testFile(s"descriptor-set-v2/$descFileName").replace("file:/", "/")
+
+  def descFilePathV3(descFileName: String): String =
+    testFile(s"descriptor-set-v3/$descFileName").replace("file:/", "/")
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 33883a2efaa5..6577e9ef5e42 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -723,10 +723,11 @@ object SparkProtobuf {
 
     dependencyOverrides += "com.google.protobuf" % "protobuf-java" % protoVersion,
 
-    (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources",
-
+    // Protobuf V3 code is generated here. V2 is generated in another setting below.
     (Test / PB.targets) := Seq(
-      PB.gens.java -> target.value / "generated-test-sources"
+      PB.gens.java -> target.value / "generated-test-sources",
+      PB.gens.descriptorSet ->
+        target.value / "generated-test-sources/protobuf/descriptor-set.desc",
     ),
 
     (assembly / test) := { },
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 2059d868c7c1..49f741569894 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -101,7 +101,7 @@ def from_protobuf(
     assert sc is not None and sc._jvm is not None
     try:
         jc = sc._jvm.org.apache.spark.sql.protobuf.functions.from_protobuf(
-            _to_java_column(data), messageName, descFilePath, options or {}
+            _to_java_column(data), descFilePath, messageName, options or {}
         )
     except TypeError as e:
         if str(e) == "'JavaPackage' object is not callable":
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index ae425419c540..7c55805a5a94 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -463,7 +463,10 @@ private[sql] trait SQLTestUtilsBase
    * Returns full path to the given file in the resource folder
    */
   protected def testFile(fileName: String): String = {
-    Thread.currentThread().getContextClassLoader.getResource(fileName).toString
+    Option(Thread.currentThread().getContextClassLoader.getResource(fileName)) match {
+      case Some(path) => path.toString
+      case None => throw new IllegalArgumentException(s"Resource not found: $fileName")
+    }
   }
 
   /**
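Taken together, the build changes above generate versioned descriptor sets onto the test classpath, and the new ProtobufTestUtils trait resolves them through testFile(). A minimal, hypothetical sketch of how a suite can use those helpers follows; the suite name is invented, while "CatalystTypes" and "Person" mirror messages defined in the test protos earlier in this diff:

package org.apache.spark.sql.protobuf

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.protobuf.utils.ProtobufUtils
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical suite showing how the v2/v3 helpers resolve descriptor files and generated
// Java classes once the build has produced them.
class ExampleDescriptorLookupSuite
  extends SparkFunSuite
  with ProtobufTestUtils
  with SharedSparkSession {

  test("resolve generated descriptors and classes for both proto versions") {
    // Resolves to .../test-classes/descriptor-set-v2/catalyst_types.desc via the classloader.
    val v2Desc = descFilePathV2("catalyst_types.desc")
    // "org.apache.spark.sql.protobuf.protos.v3.CatalystTypes$Person"
    val v3Class = javaClassFullNameV3("CatalystTypes", "Person")

    val fromFile = ProtobufUtils.buildDescriptor(v2Desc, "Person")
    val fromClass = ProtobufUtils.buildDescriptorFromJavaClass(v3Class)
    assert(fromFile.findFieldByName("name") != null)
    assert(fromClass.findFieldByName("name") != null)
  }
}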