diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index 0e61e38ff2b0..093047ea72f5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -18,14 +18,19 @@ package org.apache.spark.api.python
 
 import java.io.File
+import java.nio.file.Paths
 import java.util.{List => JList}
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
+import scala.sys.process.Process
 
 import org.apache.spark.{SparkContext, SparkEnv}
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.internal.Logging
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
+import org.apache.spark.util.Utils
 
 private[spark] object PythonUtils extends Logging {
   val PY4J_ZIP_NAME = "py4j-0.10.9.7-src.zip"
@@ -113,11 +118,10 @@ private[spark] object PythonUtils extends Logging {
     }
 
     val pythonVersionCMD = Seq(pythonExec, "-VV")
-    val PYTHONPATH = "PYTHONPATH"
     val pythonPath = PythonUtils.mergePythonPaths(
       PythonUtils.sparkPythonPath,
-      sys.env.getOrElse(PYTHONPATH, ""))
-    val environment = Map(PYTHONPATH -> pythonPath)
+      sys.env.getOrElse("PYTHONPATH", ""))
+    val environment = Map("PYTHONPATH" -> pythonPath)
     logInfo(s"Python path $pythonPath")
 
     val processPythonVer = Process(pythonVersionCMD, None, environment.toSeq: _*)
@@ -145,4 +149,48 @@ private[spark] object PythonUtils extends Logging {
       listOfPackages.foreach(x => logInfo(s"List of Python packages :- ${formatOutput(x)}"))
     }
   }
+
+  // Only for testing.
+  private[spark] var additionalTestingPath: Option[String] = None
+
+  private[spark] def createPythonFunction(command: Array[Byte]): SimplePythonFunction = {
+    val pythonExec: String = sys.env.getOrElse(
+      "PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+
+    val sourcePython = if (Utils.isTesting) {
+      // Put the PySpark source code on the path instead of the built zip archive so we
+      // don't need to rebuild PySpark every time during development.
+      val sparkHome: String = {
+        require(
+          sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"),
+          "spark.test.home or SPARK_HOME is not set.")
+        sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+      }
+      val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath
+      val py4jPath = Paths.get(
+        sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
+      val merged = mergePythonPaths(sourcePath.toString, py4jPath.toString)
+      // Adds an additional path to search Python packages for testing.
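+      // When set, the test-only path is prepended so its packages are found first.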
+      additionalTestingPath.map(mergePythonPaths(_, merged)).getOrElse(merged)
+    } else {
+      PythonUtils.sparkPythonPath
+    }
+    val pythonPath = PythonUtils.mergePythonPaths(
+      sourcePython, sys.env.getOrElse("PYTHONPATH", ""))
+
+    val pythonVer: String =
+      Process(
+        Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
+        None,
+        "PYTHONPATH" -> pythonPath).!!.trim()
+
+    SimplePythonFunction(
+      command = command.toImmutableArraySeq,
+      envVars = mutable.Map("PYTHONPATH" -> pythonPath).asJava,
+      pythonIncludes = List.empty.asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty.asJava,
+      accumulator = null)
+  }
 }
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py
new file mode 100644
index 000000000000..91963658ee61
--- /dev/null
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -0,0 +1,99 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from importlib import import_module
+from pkgutil import iter_modules
+import os
+import sys
+from typing import IO
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_int,
+    write_int,
+    write_with_length,
+    SpecialLengths,
+)
+from pyspark.sql.datasource import DataSource
+from pyspark.util import handle_worker_exception
+from pyspark.worker_util import (
+    check_python_version,
+    pickleSer,
+    send_accumulator_updates,
+    setup_broadcasts,
+    setup_memory_limits,
+    setup_spark_files,
+)
+
+
+def main(infile: IO, outfile: IO) -> None:
+    """
+    Main method for looking up the available Python Data Sources in the Python path.
+
+    This process is invoked from the `UserDefinedPythonDataSourceLookupRunner.runInPython`
+    method in `UserDefinedPythonDataSource.lookupAllDataSourcesInPython` when the first
+    call related to Python Data Sources happens via `DataSourceManager`.
+
+    This is responsible for searching the available Python Data Sources so they can be
+    statically registered automatically.
+    """
+    try:
+        check_python_version(infile)
+
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
+        setup_memory_limits(memory_limit_mb)
+
+        setup_spark_files(infile)
+        setup_broadcasts(infile)
+
+        _accumulatorRegistry.clear()
+
+        infos = {}
+        for info in iter_modules():
+            if info.name.startswith("pyspark_"):
+                mod = import_module(info.name)
+                if hasattr(mod, "DefaultSource") and issubclass(mod.DefaultSource, DataSource):
+                    infos[mod.DefaultSource.name()] = mod.DefaultSource
+
+        # Writes name -> pickled data source to the JVM side to be registered
+        # as a Data Source.
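+        # Wire format: the number of sources, then for each source its short name
+        # (length-prefixed UTF-8) followed by the pickled DataSource class.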
+        write_int(len(infos), outfile)
+        for name, dataSource in infos.items():
+            write_with_length(name.encode("utf-8"), outfile)
+            pickleSer._write_with_length(dataSource, outfile)
+
+    except BaseException as e:
+        handle_worker_exception(e, outfile)
+        sys.exit(-1)
+
+    send_accumulator_updates(outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    main(sock_file, sock_file)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index e6c4749df60a..c207645ce526 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import scala.jdk.CollectionConverters._
+
+import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
@@ -30,9 +33,13 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
  * their short names or fully qualified names.
  */
 class DataSourceManager extends Logging {
-  // TODO(SPARK-45917): Statically load Python Data Source so idempotently Python
-  // Data Sources can be loaded even when the Driver is restarted.
-  private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
+  // Lazy to avoid being invoked during Session initialization.
+  // Otherwise, it goes into an infinite loop: session -> Python runner -> SQLConf -> session.
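+  // Builders are seeded with the statically looked-up Python Data Sources on first access.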
+  private lazy val dataSourceBuilders = {
+    val builders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
+    builders.putAll(DataSourceManager.initialDataSourceBuilders.asJava)
+    builders
+  }
 
   private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
 
@@ -73,3 +80,20 @@ class DataSourceManager extends Logging {
     manager
   }
 }
+
+
+object DataSourceManager {
+  // Visible for testing.
+  private[spark] var dataSourceBuilders: Option[Map[String, UserDefinedPythonDataSource]] = None
+  private def initialDataSourceBuilders = this.synchronized {
+    if (dataSourceBuilders.isEmpty) {
+      val result = UserDefinedPythonDataSource.lookupAllDataSourcesInPython()
+      val builders = result.names.zip(result.dataSources).map { case (name, dataSource) =>
+        name ->
+          UserDefinedPythonDataSource(PythonUtils.createPythonFunction(dataSource))
+      }.toMap
+      dataSourceBuilders = Some(builders)
+    }
+    dataSourceBuilders.get
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 778f55595aee..5f66210ad9c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._
 import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{JobArtifactSet, SparkException}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonUtils, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.PythonUDF
@@ -404,6 +404,59 @@ object UserDefinedPythonDataSource {
    * The schema of the output to the Python data source write function.
    */
   val writeOutputSchema: StructType = new StructType().add("message", BinaryType)
+
+  /**
+   * (Driver-side) Look up all available Python Data Sources.
+   */
+  def lookupAllDataSourcesInPython(): PythonLookupAllDataSourcesResult = {
+    new UserDefinedPythonDataSourceLookupRunner(
+      PythonUtils.createPythonFunction(Array.empty[Byte])).runInPython()
+  }
+}
+
+/**
+ * All Data Sources in Python
+ */
+case class PythonLookupAllDataSourcesResult(
+    names: Array[String], dataSources: Array[Array[Byte]])
+
+/**
+ * A runner used to look up Python Data Sources available in the Python path.
+ */
+class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
+  extends PythonPlannerRunner[PythonLookupAllDataSourcesResult](lookupSources) {
+
+  override val workerModule = "pyspark.sql.worker.lookup_data_sources"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // No input needed.
+  }
+
+  override protected def receiveFromPython(
+      dataIn: DataInputStream): PythonLookupAllDataSourcesResult = {
+    // Receive the pickled data sources or an exception raised in the Python worker.
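+    // The first int is either the number of data sources found, or a special negative
+    // length code indicating that the worker raised an exception.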
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.failToPlanDataSourceError(
+        action = "lookup", tpe = "instance", msg = msg)
+    }
+
+    val shortNames = ArrayBuffer.empty[String]
+    val pickledDataSources = ArrayBuffer.empty[Array[Byte]]
+    val numDataSources = length
+
+    for (_ <- 0 until numDataSources) {
+      val shortName = PythonWorkerUtils.readUTF(dataIn)
+      val pickledDataSource: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
+      shortNames.append(shortName)
+      pickledDataSources.append(pickledDataSource)
+    }
+
+    PythonLookupAllDataSourcesResult(
+      names = shortNames.toArray,
+      dataSources = pickledDataSources.toArray)
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index c261f1d529fd..45ee472ee638 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -17,11 +17,16 @@ package org.apache.spark.sql.execution.python
 
+import java.io.{File, FileWriter}
+
 import org.apache.spark.SparkException
+import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   import IntegratedUDFTestUtils._
@@ -29,7 +34,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   setupTestData()
 
   private def dataSourceName = "SimpleDataSource"
-  private def simpleDataSourceReaderScript: String =
+  private val simpleDataSourceReaderScript: String =
     """
       |from pyspark.sql.datasource import DataSourceReader, InputPartition
       |class SimpleDataSourceReader(DataSourceReader):
@@ -40,6 +45,56 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
       |        yield (1, partition.value)
      |        yield (2, partition.value)
       |""".stripMargin
+  private val staticSourceName = "custom_source"
+  private var tempDir: File = _
+
+  override def beforeAll(): Unit = {
+    // Create a Python Data Source package before starting up the Spark Session
+    // that triggers automatic registration of the Python Data Source.
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |$simpleDataSourceReaderScript
+         |
+         |class DefaultSource(DataSource):
+         |    def schema(self) -> str:
+         |        return "id INT, partition INT"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |
+         |    @classmethod
+         |    def name(cls):
+         |        return "$staticSourceName"
+         |""".stripMargin
+    tempDir = Utils.createTempDir()
+    // Write a temporary package to test:
+    //   tmp/pyspark_mysource
+    //   tmp/pyspark_mysource/__init__.py
+    val packageDir = new File(tempDir, "pyspark_mysource")
+    assert(packageDir.mkdir())
+    Utils.tryWithResource(
+      new FileWriter(new File(packageDir, "__init__.py")))(_.write(dataSourceScript))
+    // So Spark Session initialization can look up this temporary directory.
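+    // Clearing the cached lookup result forces a fresh lookup that picks up the package above.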
+    DataSourceManager.dataSourceBuilders = None
+    PythonUtils.additionalTestingPath = Some(tempDir.toString)
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+      PythonUtils.additionalTestingPath = None
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  test("SPARK-45917: automatic registration of Python Data Source") {
+    assume(shouldTestPandasUDFs)
+    val df = spark.read.format(staticSourceName).load()
+    checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
+  }
 
   test("simple data source") {
     assume(shouldTestPandasUDFs)