diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 4916ff1f5974..801c8cb0263e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -3232,14 +3232,15 @@ class PlanGenerationTestSuite
     "connect/common/src/test/resources/protobuf-tests/common.desc"

   test("from_protobuf messageClassName") {
-    binary.select(pbFn.from_protobuf(fn.col("bytes"), classOf[StorageLevel].getName))
+    binary.select(
+      pbFn.from_protobuf(fn.col("bytes"), "org.apache.spark.sql.protobuf.protos.TestProtoObj"))
   }

   test("from_protobuf messageClassName options") {
     binary.select(
       pbFn.from_protobuf(
         fn.col("bytes"),
-        classOf[StorageLevel].getName,
+        "org.apache.spark.sql.protobuf.protos.TestProtoObj",
         Map("recursive.fields.max.depth" -> "2").asJava))
   }

diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain
index e7a1867fe907..6f48cb090cde 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName.explain
@@ -1,2 +1,2 @@
-Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None) AS from_protobuf(bytes)#0]
+Project [from_protobuf(bytes#0, org.apache.spark.sql.protobuf.protos.TestProtoObj, None) AS from_protobuf(bytes)#0]
 +- LocalRelation <empty>, [id#0L, bytes#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain
index c02d829fcac1..ba87e4774f1a 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/from_protobuf_messageClassName_options.explain
@@ -1,2 +1,2 @@
-Project [from_protobuf(bytes#0, org.apache.spark.connect.proto.StorageLevel, None, (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0]
+Project [from_protobuf(bytes#0, org.apache.spark.sql.protobuf.protos.TestProtoObj, None, (recursive.fields.max.depth,2)) AS from_protobuf(bytes)#0]
 +- LocalRelation <empty>, [id#0L, bytes#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json
index dc23ac2a117b..6c5891e70165 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.json
@@ -20,7 +20,7 @@
         }
       }, {
         "literal": {
-          "string": "org.apache.spark.connect.proto.StorageLevel"
+          "string": "org.apache.spark.sql.protobuf.protos.TestProtoObj"
         }
       }]
     }
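Note: the golden files above (and the binary `.proto.bin` files below) are regenerated by the test suite rather than hand-edited. As a minimal sketch of the call shape these tests now pin down — not part of this diff, and assuming a local `SparkSession`, the `spark-protobuf` functions on the classpath, and the generated `TestProtoObj` class available at analysis time (`df` stands in for the suite's `binary` relation):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.{functions => pbFn}

val spark = SparkSession.builder().master("local[1]").getOrCreate()
// Stand-in for the suite's `binary` relation: an id column plus a binary column.
val df = spark.range(1).selectExpr("id", "cast(null as binary) as bytes")

// Resolve the message by compiled Java class name, as the updated tests do.
val parsed = df.select(
  pbFn.from_protobuf(col("bytes"), "org.apache.spark.sql.protobuf.protos.TestProtoObj"))
parsed.explain() // plans the from_protobuf(...) Project checked in the .explain files
```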
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin
index cc46234b7476..9d7aeaf30896 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json
index 36f69646ef83..691e144eabd0 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.json
@@ -20,7 +20,7 @@
         }
       }, {
         "literal": {
-          "string": "org.apache.spark.connect.proto.StorageLevel"
+          "string": "org.apache.spark.sql.protobuf.protos.TestProtoObj"
         }
       }, {
         "unresolvedFunction": {
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin
index 72a1c6b8207e..89000e473fac 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/from_protobuf_messageClassName_options.proto.bin differ
diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml
index e98b8da8e5c0..34c919bc6141 100644
--- a/connector/connect/server/pom.xml
+++ b/connector/connect/server/pom.xml
@@ -248,6 +248,13 @@
     </dependency>
   </dependencies>
   <build>
+    <extensions>
+      <extension>
+        <groupId>kr.motd.maven</groupId>
+        <artifactId>os-maven-plugin</artifactId>
+        <version>1.6.2</version>
+      </extension>
+    </extensions>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
@@ -403,6 +410,87 @@
         </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-antrun-plugin</artifactId>
+        <executions>
+          <execution>
+            <phase>process-test-sources</phase>
+            <configuration>
+              <!-- Mirrors the SBT `rewriteJavaFile` task: point the generated
+                   test sources at the shaded protobuf package. -->
+              <target>
+                <replace dir="${project.basedir}/target/generated-test-sources"
+                         token="com.google.protobuf."
+                         value="org.sparkproject.spark_protobuf.protobuf.">
+                  <include name="**/*.java"/>
+                </replace>
+              </target>
+            </configuration>
+            <goals>
+              <goal>run</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
     </plugins>
   </build>
+
+  <profiles>
+    <profile>
+      <id>default-protoc</id>
+      <activation>
+        <activeByDefault>true</activeByDefault>
+      </activation>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+              <protoSourceRoot>src/test/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>test-compile</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+    <profile>
+      <id>user-defined-protoc</id>
+      <properties>
+        <spark.protoc.executable.path>${env.SPARK_PROTOC_EXEC_PATH}</spark.protoc.executable.path>
+      </properties>
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.xolstice.maven.plugins</groupId>
+            <artifactId>protobuf-maven-plugin</artifactId>
+            <version>0.6.1</version>
+            <configuration>
+              <protocExecutable>${spark.protoc.executable.path}</protocExecutable>
+              <protoSourceRoot>src/test/protobuf</protoSourceRoot>
+            </configuration>
+            <executions>
+              <execution>
+                <goals>
+                  <goal>test-compile</goal>
+                </goals>
+              </execution>
+            </executions>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+  </profiles>
 </project>
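For context, the two Maven profiles encode the same decision the SBT settings make further down: compile the test protos with a `protoc` downloaded for the detected platform by default, or with a user-supplied binary when `SPARK_PROTOC_EXEC_PATH` is exported. A hedged, illustrative Scala sketch of that decision only — the version and classifier literals are placeholders, not values read from the pom:

```scala
// Placeholder values; the pom uses ${protobuf.version} and the classifier
// detected by os-maven-plugin (e.g. "linux-x86_64", "osx-aarch_64").
val protobufVersion = "3.23.2"
val osClassifier = "linux-x86_64"

// "user-defined-protoc" wins when the environment variable is set; otherwise
// "default-protoc" resolves a protoc artifact from Maven Central.
val protoc: String = sys.env.getOrElse(
  "SPARK_PROTOC_EXEC_PATH",
  s"com.google.protobuf:protoc:$protobufVersion:exe:$osClassifier")
println(protoc)
```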
diff --git a/connector/connect/server/src/test/protobuf/test.proto b/connector/connect/server/src/test/protobuf/test.proto
new file mode 100644
index 000000000000..844f89ba81f4
--- /dev/null
+++ b/connector/connect/server/src/test/protobuf/test.proto
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+package org.apache.spark.sql.protobuf.protos;
+
+option java_multiple_files = true;
+option java_outer_classname = "TestProto";
+
+message TestProtoObj {
+  int64 v1 = 1;
+  int32 v2 = 2;
+}
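The new message is deliberately tiny: two scalar fields are enough for the golden plans, and `java_multiple_files = true` makes `TestProtoObj` its own top-level Java class, whose fully qualified name the tests reference. As a sketch (an assumption, not asserted by this diff) of the Catalyst schema `from_protobuf` should derive for it under the usual proto3 scalar mappings — `int64` to `LongType`, `int32` to `IntegerType`:

```scala
import org.apache.spark.sql.types._

// Assumed mapping for TestProtoObj; proto3 scalar fields come back nullable.
val expectedSchema = StructType(Seq(
  StructField("v1", LongType, nullable = true),
  StructField("v2", IntegerType, nullable = true)))
```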
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 2f437eeb75cc..978165ba0dab 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -17,7 +17,7 @@
 import java.io._
 import java.nio.charset.StandardCharsets.UTF_8
-import java.nio.file.Files
+import java.nio.file.{Files, StandardOpenOption}
 import java.util.Locale

 import scala.io.Source
@@ -765,7 +765,13 @@ object SparkConnectCommon {
 object SparkConnect {
   import BuildCommons.protoVersion

+  val rewriteJavaFile = TaskKey[Unit]("rewriteJavaFile",
+    "Rewrite the generated Java PB files.")
+  val genPBAndRewriteJavaFile = TaskKey[Unit]("genPBAndRewriteJavaFile",
+    "Generate Java PB files and overwrite their contents.")
+
   lazy val settings = Seq(
+    PB.protocVersion := BuildCommons.protoVersion,
     // For some reason the resolution from the imported Maven build does not work for some
     // of these dependendencies that we need to shade later on.
     libraryDependencies ++= {
@@ -796,6 +802,42 @@ object SparkConnect {
       )
     },

+    // SPARK-43646: The following three statements make the connect module use the
+    // spark-protobuf assembly jar when compiling and testing with SBT, matching the
+    // behavior of the Maven build.
+    (Test / unmanagedJars) += (LocalProject("protobuf") / assembly).value,
+    (Test / fullClasspath) :=
+      (Test / fullClasspath).value.filterNot { f => f.toString.contains("spark-protobuf") },
+    (Test / fullClasspath) += (LocalProject("protobuf") / assembly).value,
+
+    (Test / PB.protoSources) += (Test / sourceDirectory).value / "resources" / "protobuf",
+
+    (Test / PB.targets) := Seq(
+      PB.gens.java -> target.value / "generated-test-sources",
+    ),
+
+    // SPARK-43646: Create a custom task to replace all `com.google.protobuf.` with
+    // `org.sparkproject.spark_protobuf.protobuf.` in the generated Java PB files.
+    // This generates Java files that can be used to test spark-protobuf functions
+    // in `ProtoToParsedPlanTestSuite`.
+    rewriteJavaFile := {
+      val protobufDir = target.value / "generated-test-sources" / "org" / "apache" / "spark" / "sql" / "protobuf" / "protos"
+      protobufDir.listFiles().foreach { f =>
+        if (f.getName.endsWith(".java")) {
+          val contents = Files.readAllLines(f.toPath, UTF_8)
+          val replaced = contents.asScala.map { line =>
+            line.replaceAll("com.google.protobuf.", "org.sparkproject.spark_protobuf.protobuf.")
+          }
+          Files.write(f.toPath, replaced.asJava, StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE)
+        }
+      }
+    },
+    // SPARK-43646: `genPBAndRewriteJavaFile` fixes the execution order of `PB.generate`
+    // and `rewriteJavaFile`, and `Test / compile` is made to depend on
+    // `genPBAndRewriteJavaFile` so that both run first.
+    genPBAndRewriteJavaFile := Def.sequential(Test / PB.generate, rewriteJavaFile).value,
+    (Test / compile) := (Test / compile).dependsOn(genPBAndRewriteJavaFile).value,
+
     (assembly / test) := { },

     (assembly / logLevel) := Level.Info,
@@ -841,7 +883,16 @@ object SparkConnect {
       case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard
       case _ => MergeStrategy.first
     }
-  )
+  ) ++ {
+    val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
+    if (sparkProtocExecPath.isDefined) {
+      Seq(
+        PB.protocExecutable := file(sparkProtocExecPath.get)
+      )
+    } else {
+      Seq.empty
+    }
+  }
 }

 object SparkConnectClient {
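The `rewriteJavaFile` task above only touches `.java` files directly under the generated `protos` directory, which matches the flat layout produced for this package. A standalone sketch of the same rewrite for a single file, runnable outside sbt (same semantics assumed; note `replaceAll` takes a regex, so the dots match any character, which is harmless for this token):

```scala
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.{Files, Paths, StandardOpenOption}
import scala.jdk.CollectionConverters._

// Rewrite references to the unshaded protobuf runtime into the package
// that the spark-protobuf assembly shades them to.
def shadeProtobufRefs(javaFile: String): Unit = {
  val path = Paths.get(javaFile)
  val rewritten = Files.readAllLines(path, UTF_8).asScala.map(
    _.replaceAll("com.google.protobuf.", "org.sparkproject.spark_protobuf.protobuf."))
  Files.write(path, rewritten.asJava,
    StandardOpenOption.TRUNCATE_EXISTING, StandardOpenOption.WRITE)
}
```

With this in place, both builds exercise the shaded runtime the same way: Maven via the antrun `replace` task after `protobuf-maven-plugin` generates the test sources, and SBT via `Def.sequential(Test / PB.generate, rewriteJavaFile)` before `Test / compile`.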