54 changes: 51 additions & 3 deletions core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -18,14 +18,19 @@
package org.apache.spark.api.python

import java.io.File
import java.nio.file.Paths
import java.util.{List => JList}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.sys.process.Process

import org.apache.spark.{SparkContext, SparkEnv}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.util.ArrayImplicits.SparkArrayOps
import org.apache.spark.util.Utils

private[spark] object PythonUtils extends Logging {
val PY4J_ZIP_NAME = "py4j-0.10.9.7-src.zip"
@@ -113,11 +118,10 @@ private[spark] object PythonUtils extends Logging {
}

val pythonVersionCMD = Seq(pythonExec, "-VV")
val PYTHONPATH = "PYTHONPATH"
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
sys.env.getOrElse(PYTHONPATH, ""))
val environment = Map(PYTHONPATH -> pythonPath)
sys.env.getOrElse("PYTHONPATH", ""))
val environment = Map("PYTHONPATH" -> pythonPath)
logInfo(s"Python path $pythonPath")

val processPythonVer = Process(pythonVersionCMD, None, environment.toSeq: _*)
@@ -145,4 +149,48 @@ private[spark] object PythonUtils extends Logging {
listOfPackages.foreach(x => logInfo(s"List of Python packages :- ${formatOutput(x)}"))
}
}

// Only for testing.
private[spark] var additionalTestingPath: Option[String] = None

private[spark] def createPythonFunction(command: Array[Byte]): SimplePythonFunction = {
val pythonExec: String = sys.env.getOrElse(
"PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))

val sourcePython = if (Utils.isTesting) {
// Put PySpark source code instead of the built zip archive so we don't need
// to build PySpark every time during development.
val sparkHome: String = {
require(
sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"),
"spark.test.home or SPARK_HOME is not set.")
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
}
val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath
val py4jPath = Paths.get(
sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
Review comment on lines +170 to +171 (Member): Do we need Py4J path? The Python functions are not supposed to use Py4J?
@HyukjinKwon (Member Author) replied on Dec 27, 2023.
val merged = mergePythonPaths(sourcePath.toString, py4jPath.toString)
// Adds an additional path to search for Python packages during testing.
additionalTestingPath.map(mergePythonPaths(_, merged)).getOrElse(merged)
} else {
PythonUtils.sparkPythonPath
}
val pythonPath = PythonUtils.mergePythonPaths(
sourcePython, sys.env.getOrElse("PYTHONPATH", ""))

val pythonVer: String =
Process(
Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
None,
"PYTHONPATH" -> pythonPath).!!.trim()

SimplePythonFunction(
command = command.toImmutableArraySeq,
envVars = mutable.Map("PYTHONPATH" -> pythonPath).asJava,
pythonIncludes = List.empty.asJava,
pythonExec = pythonExec,
pythonVer = pythonVer,
broadcastVars = List.empty.asJava,
accumulator = null)
}
}
99 changes: 99 additions & 0 deletions python/pyspark/sql/worker/lookup_data_sources.py
@@ -0,0 +1,99 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from importlib import import_module
from pkgutil import iter_modules
import os
import sys
from typing import IO

from pyspark.accumulators import _accumulatorRegistry
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import (
read_int,
write_int,
write_with_length,
SpecialLengths,
)
from pyspark.sql.datasource import DataSource
from pyspark.util import handle_worker_exception
from pyspark.worker_util import (
check_python_version,
pickleSer,
send_accumulator_updates,
setup_broadcasts,
setup_memory_limits,
setup_spark_files,
)


def main(infile: IO, outfile: IO) -> None:
"""
Main method for looking up the available Python Data Sources in the Python path.

This process is invoked from the `UserDefinedPythonDataSourceLookupRunner.runInPython`
method in `UserDefinedPythonDataSource.lookupAllDataSourcesInPython` when the first
call related to a Python Data Source happens via `DataSourceManager`.

This is responsible for searching the available Python Data Sources so they can be
statically registered automatically.
"""
try:
check_python_version(infile)

memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)

setup_spark_files(infile)
setup_broadcasts(infile)

_accumulatorRegistry.clear()

infos = {}
for info in iter_modules():
if info.name.startswith("pyspark_"):
mod = import_module(info.name)
if hasattr(mod, "DefaultSource") and issubclass(mod.DefaultSource, DataSource):
infos[mod.DefaultSource.name()] = mod.DefaultSource

# Write each name -> pickled data source pair to the JVM side so it can be
# registered as a Data Source.
write_int(len(infos), outfile)
for name, dataSource in infos.items():
write_with_length(name.encode("utf-8"), outfile)
pickleSer._write_with_length(dataSource, outfile)

except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)

send_accumulator_updates(outfile)

# check end of stream
if read_int(infile) == SpecialLengths.END_OF_STREAM:
write_int(SpecialLengths.END_OF_STREAM, outfile)
else:
# write a different value to tell JVM to not reuse this worker
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)


if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
main(sock_file, sock_file)
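
Note: the worker above only discovers top-level modules whose names start with `pyspark_` and that expose a `DefaultSource` subclass of `DataSource`, registering each under the short name returned by `DefaultSource.name()`. A minimal sketch of such a package (hypothetical names `pyspark_fake_source` and `fake_source`, not part of this PR):

```python
# pyspark_fake_source/__init__.py -- hypothetical package, not part of this PR.
# Any top-level module named "pyspark_*" that exposes a DefaultSource subclass
# of DataSource is picked up by lookup_data_sources.py and registered under
# the short name returned by DefaultSource.name().
from pyspark.sql.datasource import DataSource, DataSourceReader


class _FakeReader(DataSourceReader):
    def read(self, partition):
        # Yield rows as plain tuples matching the declared schema.
        yield (0, "hello")
        yield (1, "world")


class DefaultSource(DataSource):
    @classmethod
    def name(cls):
        # Short name used with spark.read.format(...).
        return "fake_source"

    def schema(self):
        return "id INT, value STRING"

    def reader(self, schema):
        return _FakeReader()
```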
@@ -20,6 +20,9 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import scala.jdk.CollectionConverters._

import org.apache.spark.api.python.PythonUtils
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
@@ -30,9 +33,13 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
* their short names or fully qualified names.
*/
class DataSourceManager extends Logging {
// TODO(SPARK-45917): Statically load Python Data Source so idempotently Python
// Data Sources can be loaded even when the Driver is restarted.
private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
// Lazy to avoid being invoked during Session initialization.
// Otherwise, it goes into an infinite loop: session -> Python runner -> SQLConf -> session.
private lazy val dataSourceBuilders = {
val builders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()
builders.putAll(DataSourceManager.initialDataSourceBuilders.asJava)
builders
}

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

@@ -73,3 +80,20 @@ class DataSourceManager extends Logging {
manager
}
}


object DataSourceManager {
// Visible for testing.
private[spark] var dataSourceBuilders: Option[Map[String, UserDefinedPythonDataSource]] = None
private def initialDataSourceBuilders = this.synchronized {
if (dataSourceBuilders.isEmpty) {
val result = UserDefinedPythonDataSource.lookupAllDataSourcesInPython()
val builders = result.names.zip(result.dataSources).map { case (name, dataSource) =>
name ->
UserDefinedPythonDataSource(PythonUtils.createPythonFunction(dataSource))
}.toMap
dataSourceBuilders = Some(builders)
}
dataSourceBuilders.get
}
}
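
Note: once `initialDataSourceBuilders` has run (once per JVM, guarded by `this.synchronized`), a discovered source is usable by its short name with no manual registration step. A minimal usage sketch in Python, assuming a package like the hypothetical `pyspark_fake_source` above is importable on the driver's `PYTHONPATH`:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# No manual registration step: the short name resolves through the builders
# that DataSourceManager loaded from the one-time Python lookup.
df = spark.read.format("fake_source").load()
df.show()
```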
@@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler

import org.apache.spark.{JobArtifactSet, SparkException}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonUtils, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.PythonUDF
@@ -404,6 +404,59 @@ object UserDefinedPythonDataSource {
* The schema of the output to the Python data source write function.
*/
val writeOutputSchema: StructType = new StructType().add("message", BinaryType)

/**
* (Driver-side) Look up all available Python Data Sources.
*/
def lookupAllDataSourcesInPython(): PythonLookupAllDataSourcesResult = {
new UserDefinedPythonDataSourceLookupRunner(
PythonUtils.createPythonFunction(Array.empty[Byte])).runInPython()
}
}

/**
* The result of looking up all Python Data Sources available in the Python path.
*/
case class PythonLookupAllDataSourcesResult(
names: Array[String], dataSources: Array[Array[Byte]])

/**
* A runner used to look up Python Data Sources available in the Python path.
*/
class UserDefinedPythonDataSourceLookupRunner(lookupSources: PythonFunction)
extends PythonPlannerRunner[PythonLookupAllDataSourcesResult](lookupSources) {

override val workerModule = "pyspark.sql.worker.lookup_data_sources"

override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
// No input needed.
}

override protected def receiveFromPython(
dataIn: DataInputStream): PythonLookupAllDataSourcesResult = {
// Receive the pickled data sources, or an exception raised in the Python worker.
val length = dataIn.readInt()
if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
val msg = PythonWorkerUtils.readUTF(dataIn)
throw QueryCompilationErrors.failToPlanDataSourceError(
action = "lookup", tpe = "instance", msg = msg)
}

val shortNames = ArrayBuffer.empty[String]
val pickledDataSources = ArrayBuffer.empty[Array[Byte]]
val numDataSources = length

for (_ <- 0 until numDataSources) {
val shortName = PythonWorkerUtils.readUTF(dataIn)
val pickledDataSource: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
shortNames.append(shortName)
pickledDataSources.append(pickledDataSource)
}

PythonLookupAllDataSourcesResult(
names = shortNames.toArray,
dataSources = pickledDataSources.toArray)
}
}

@@ -17,19 +17,24 @@

package org.apache.spark.sql.execution.python

import java.io.{File, FileWriter}

import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
import org.apache.spark.sql.execution.datasources.DataSourceManager
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
import IntegratedUDFTestUtils._

setupTestData()

private def dataSourceName = "SimpleDataSource"
private def simpleDataSourceReaderScript: String =
private val simpleDataSourceReaderScript: String =
"""
|from pyspark.sql.datasource import DataSourceReader, InputPartition
|class SimpleDataSourceReader(DataSourceReader):
@@ -40,6 +45,56 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
| yield (1, partition.value)
| yield (2, partition.value)
|""".stripMargin
private val staticSourceName = "custom_source"
private var tempDir: File = _

override def beforeAll(): Unit = {
// Create a Python Data Source package before starting up the Spark Session,
// whose initialization triggers automatic registration of the Python Data Source.
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
|$simpleDataSourceReaderScript
|
|class DefaultSource(DataSource):
| def schema(self) -> str:
| return "id INT, partition INT"
|
| def reader(self, schema):
| return SimpleDataSourceReader()
|
| @classmethod
| def name(cls):
| return "$staticSourceName"
|""".stripMargin
tempDir = Utils.createTempDir()
// Write a temporary package to test.
// tmp/pyspark_mysource
// tmp/pyspark_mysource/__init__.py
val packageDir = new File(tempDir, "pyspark_mysource")
assert(packageDir.mkdir())
Utils.tryWithResource(
new FileWriter(new File(packageDir, "__init__.py")))(_.write(dataSourceScript))
// So Spark Session initialization can look up this temporary directory.
DataSourceManager.dataSourceBuilders = None
PythonUtils.additionalTestingPath = Some(tempDir.toString)
super.beforeAll()
}

override def afterAll(): Unit = {
try {
Utils.deleteRecursively(tempDir)
PythonUtils.additionalTestingPath = None
} finally {
super.afterAll()
}
}

test("SPARK-45917: automatic registration of Python Data Source") {
assume(shouldTestPandasUDFs)
val df = spark.read.format(staticSourceName).load()
checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
}

test("simple data source") {
assume(shouldTestPandasUDFs)