diff --git a/dev/lint-python b/dev/lint-python index e3a11530c2b4..8d587bd52aca 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -278,7 +278,7 @@ function black_test { fi echo "starting black test..." - BLACK_REPORT=$( ($BLACK_BUILD --config dev/pyproject.toml --check python/pyspark dev python/setup.py) 2>&1) + BLACK_REPORT=$( ($BLACK_BUILD --config dev/pyproject.toml --check python/pyspark dev python/packaging) 2>&1) BLACK_STATUS=$? if [ "$BLACK_STATUS" -ne 0 ]; then diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 70684a02a8dd..d2492c72fcf7 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -248,7 +248,7 @@ if [ "$MAKE_PIP" == "true" ]; then pushd "$SPARK_HOME/python" > /dev/null # Delete the egg info file if it exists, this can cache older setup files. rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - python3 setup.py sdist + python3 packaging/classic/setup.py sdist popd > /dev/null else echo "Skipping building python distribution package" diff --git a/dev/reformat-python b/dev/reformat-python index d2a56f18c397..46b7efc931aa 100755 --- a/dev/reformat-python +++ b/dev/reformat-python @@ -29,4 +29,4 @@ if [ $? -ne 0 ]; then exit 1 fi -$BLACK_BUILD --config dev/pyproject.toml python/pyspark dev python/setup.py +$BLACK_BUILD --config dev/pyproject.toml python/pyspark dev python/packaging diff --git a/dev/run-pip-tests b/dev/run-pip-tests index 773611d9d922..f8a547b0c917 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -73,7 +73,7 @@ PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall" # Test both regular user and edit/dev install modes. PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" - "pip install $PIP_OPTIONS -e python/") + "pip install $PIP_OPTIONS -e python/packaging/classic") # Jenkins has PySpark installed under user sitepackages shared for some reasons. # In this test, explicitly exclude user sitepackages to prevent side effects @@ -103,7 +103,7 @@ for python in "${PYTHON_EXECS[@]}"; do cd "$FWDIR"/python # Delete the egg info file if it exists, this can cache the setup file. 
rm -rf pyspark.egg-info || echo "No existing egg info file, skipping deletion" - python3 setup.py sdist + python3 packaging/classic/setup.py sdist echo "Installing dist into virtual env" @@ -125,7 +125,7 @@ for python in "${PYTHON_EXECS[@]}"; do echo "Run basic sanity check with import based" python3 "$FWDIR"/dev/pip-sanity-check.py echo "Run the tests for context.py" - python3 "$FWDIR"/python/pyspark/context.py + python3 "$FWDIR"/python/pyspark/core/context.py cd "$FWDIR" diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 7604e8f4b0a0..6b087436c687 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -430,12 +430,12 @@ def __hash__(self): source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming))"], python_test_goals=[ # doctests - "pyspark.rdd", - "pyspark.context", - "pyspark.conf", - "pyspark.broadcast", + "pyspark.core.rdd", + "pyspark.core.context", + "pyspark.core.conf", + "pyspark.core.broadcast", "pyspark.accumulators", - "pyspark.files", + "pyspark.core.files", "pyspark.serializers", "pyspark.profiler", "pyspark.shuffle", diff --git a/docs/building-spark.md b/docs/building-spark.md index 56efbc1a0110..d10dfc9434fe 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -216,7 +216,7 @@ For information about how to run individual tests, refer to the If you are building Spark for use in a Python environment and you wish to pip install it, you will first need to build the Spark JARs as described above. Then you can construct an sdist package suitable for setup.py and pip installable package. - cd python; python setup.py sdist + cd python; python packaging/classic/setup.py sdist **Note:** Due to packaging requirements you can not directly pip install from the Python directory, rather you must first build the sdist package as described above. diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index 7f69272cbeb0..f75bda0ffafb 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -1321,7 +1321,7 @@ method. 
The code below shows this: {% highlight python %} >>> broadcastVar = sc.broadcast([1, 2, 3]) - + >>> broadcastVar.value [1, 2, 3] diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index e303860b32a6..97475dde14ec 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -47,7 +47,7 @@ from typing import Any, Tuple from functools import reduce -from pyspark.rdd import RDD +from pyspark import RDD from pyspark.sql import SparkSession if __name__ == "__main__": diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index 380cb7bdefdb..f9933f1a78da 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -32,7 +32,7 @@ import sys from typing import Any, Tuple -from pyspark.rdd import RDD +from pyspark import RDD from pyspark.sql import SparkSession if __name__ == "__main__": diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index bd198817b03b..9ef2d5dbaff4 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -18,7 +18,7 @@ import sys from typing import Tuple -from pyspark.rdd import RDD +from pyspark import RDD from pyspark.sql import SparkSession diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index b3f2114b9e8c..fae40a77acaf 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -34,7 +34,7 @@ from typing import Tuple from pyspark import SparkContext -from pyspark.rdd import RDD +from pyspark import RDD from pyspark.streaming import DStream, StreamingContext diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py index 9d3fe4c30ec6..147d3c646799 100644 --- a/examples/src/main/python/streaming/recoverable_network_wordcount.py +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -40,10 +40,7 @@ import sys from typing import List, Tuple -from pyspark import SparkContext -from pyspark.accumulators import Accumulator -from pyspark.broadcast import Broadcast -from pyspark.rdd import RDD +from pyspark import SparkContext, Accumulator, Broadcast, RDD from pyspark.streaming import StreamingContext diff --git a/examples/src/main/python/streaming/sql_network_wordcount.py b/examples/src/main/python/streaming/sql_network_wordcount.py index 9518cb70ba78..bba398c0d610 100644 --- a/examples/src/main/python/streaming/sql_network_wordcount.py +++ b/examples/src/main/python/streaming/sql_network_wordcount.py @@ -30,8 +30,7 @@ import sys import datetime -from pyspark import SparkConf, SparkContext -from pyspark.rdd import RDD +from pyspark import SparkConf, SparkContext, RDD from pyspark.streaming import StreamingContext from pyspark.sql import Row, SparkSession diff --git a/pom.xml b/pom.xml index ca949a05c81c..1e31e2bf2574 100644 --- a/pom.xml +++ b/pom.xml @@ -222,7 +222,8 @@ 72.1 15.0.2 3.0.0-M1 diff --git a/python/.gitignore b/python/.gitignore index 52128cf844a7..7967aebe9bf3 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,5 +1,8 @@ *.pyc docs/_build/ pyspark.egg-info +pyspark_connect.egg-info build/ dist/ +./setup.py +./setup.cfg diff --git a/python/setup.cfg b/python/packaging/classic/setup.cfg 
similarity index 100%
rename from python/setup.cfg
rename to python/packaging/classic/setup.cfg
diff --git a/python/setup.py b/python/packaging/classic/setup.py
similarity index 92%
rename from python/setup.py
rename to python/packaging/classic/setup.py
index ec7240107d1b..5242f749d622 100755
--- a/python/setup.py
+++ b/python/packaging/classic/setup.py
@@ -24,6 +24,25 @@
 from setuptools import setup
 from setuptools.command.install import install
 from shutil import copyfile, copytree, rmtree
+from pathlib import Path
+
+if (
+    # When we package, the current working directory is this file's parent, the 'classic' dir
+    # (as with pip install -e python/packaging/classic)
+    os.getcwd() == str(Path(__file__).parent.absolute())
+    and str(Path(__file__).parent.name) == "classic"
+):
+    # For:
+    # - pip install -e python/packaging/classic
+    #   It moves the current working directory to 'classic'
+    # - cd python/packaging/classic; python setup.py sdist
+    #
+    # For:
+    # - python packaging/classic/setup.py sdist, it does not
+    #   execute this branch.
+    #
+    # Move to spark/python
+    os.chdir(Path(__file__).parent.parent.parent.absolute())

 try:
     exec(open("pyspark/version.py").read())
@@ -58,7 +77,7 @@
       ./build/mvn -DskipTests clean package
     Building the source dist is done in the Python directory:
       cd python
-      python setup.py sdist
+      python packaging/classic/setup.py sdist
       pip install dist/*.tar.gz"""

 # Figure out where the jars are we need to package with PySpark.
@@ -129,7 +148,8 @@ def _supports_symlinks():
 # If you are changing the versions here, please also change ./python/pyspark/sql/pandas/utils.py
 # For Arrow, you should also check ./pom.xml and ensure there are no breaking changes in the
 # binary format protocol with the Java version, see ARROW_HOME/format/* for specifications.
-# Also don't forget to update python/docs/source/getting_started/install.rst.
+# Also don't forget to update python/docs/source/getting_started/install.rst, and
+# python/packaging/connect/setup.py
 _minimum_pandas_version = "1.4.4"
 _minimum_numpy_version = "1.21"
 _minimum_pyarrow_version = "4.0.0"
@@ -184,8 +204,11 @@ def run(self):
     copyfile("pyspark/shell.py", "pyspark/python/pyspark/shell.py")

     if in_spark:
-        # Construct the symlink farm - this is necessary since we can't refer to the path above the
-        # package root and we need to copy the jars and scripts which are up above the python root.
+        copyfile("packaging/classic/setup.py", "setup.py")
+        copyfile("packaging/classic/setup.cfg", "setup.cfg")
+        # Construct the symlink farm - this is necessary since we can't refer to
+        # the path above the package root and we need to copy the jars and scripts which
+        # are up above the python root.
         if _supports_symlinks():
             os.symlink(JARS_PATH, JARS_TARGET)
             os.symlink(SCRIPTS_PATH, SCRIPTS_TARGET)
@@ -234,6 +257,7 @@ def run(self):
         url="https://github.com/apache/spark/tree/master/python",
         packages=[
             "pyspark",
+            "pyspark.core",
             "pyspark.cloudpickle",
             "pyspark.mllib",
             "pyspark.mllib.linalg",
@@ -352,6 +376,8 @@ def run(self):
     # We only cleanup the symlink farm if we were in Spark, otherwise we are installing rather than
     # packaging.
     if in_spark:
+        os.remove("setup.py")
+        os.remove("setup.cfg")
         # Depending on cleaning up the symlink farm or copied version
         if _supports_symlinks():
             os.remove(os.path.join(TEMP_PATH, "jars"))
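Editor's note: both relocated setup.py files depend on the working-directory check added above. The condensed sketch below paraphrases that logic (it is not the verbatim file; the directory name is "classic" here and "connect" in the copy that follows). When the script is started from inside python/packaging/classic, which is what `pip install -e python/packaging/classic` does, it hops back up to spark/python so that relative paths such as "pyspark/version.py" and the jars/scripts symlink farm keep resolving; running `python packaging/classic/setup.py sdist` from spark/python never enters the branch.

```python
import os
from pathlib import Path

# Condensed sketch of the check above: only chdir when we were started from
# inside the packaging directory itself (editable installs, or running
# `python setup.py sdist` from within python/packaging/classic).
here = Path(__file__).parent.absolute()      # .../python/packaging/classic
if os.getcwd() == str(here) and here.name == "classic":
    # classic -> packaging -> python: restore spark/python as the working dir.
    os.chdir(here.parent.parent)
```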
diff --git a/python/packaging/connect/setup.cfg b/python/packaging/connect/setup.cfg
new file mode 100644
index 000000000000..ed26209d369d
--- /dev/null
+++ b/python/packaging/connect/setup.cfg
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+[bdist_wheel]
+universal = 1
+
+[metadata]
+description_file = README.md
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
new file mode 100755
index 000000000000..f77074a1bb20
--- /dev/null
+++ b/python/packaging/connect/setup.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# cd python
+# python packaging/connect/setup.py sdist
+
+# cd python/packaging/connect
+# python setup.py sdist
+
+import sys
+from setuptools import setup
+import os
+from shutil import copyfile
+import glob
+from pathlib import Path
+
+if (
+    # When we package, the current working directory is this file's parent, the 'connect' dir
+    # (as with pip install -e python/packaging/connect)
+    os.getcwd() == str(Path(__file__).parent.absolute())
+    and str(Path(__file__).parent.name) == "connect"
+):
+    # For:
+    # - pip install -e python/packaging/connect
+    #   It moves the current working directory to 'connect'
+    # - cd python/packaging/connect; python setup.py sdist
+    #
+    # For:
+    # - python packaging/connect/setup.py sdist, it does not
+    #   execute this branch.
+    #
+    # Move to spark/python
+    os.chdir(Path(__file__).parent.parent.parent.absolute())
+
+try:
+    exec(open("pyspark/version.py").read())
+except IOError:
+    print(
+        "Failed to load PySpark version file for packaging. You must be in Spark's python dir.",
+        file=sys.stderr,
+    )
+    sys.exit(-1)
+VERSION = __version__  # noqa
+
+# Check and see if we are under the spark path in which case we need to build the symlink farm.
+# This is important because we only want to build the symlink farm while under Spark otherwise we +# want to use the symlink farm. And if the symlink farm exists under while under Spark (e.g. a +# partially built sdist) we should error and have the user sort it out. +in_spark = os.path.isfile("../core/src/main/scala/org/apache/spark/SparkContext.scala") or ( + os.path.isfile("../RELEASE") and len(glob.glob("../jars/spark*core*.jar")) == 1 +) + +try: + if in_spark: + copyfile("packaging/connect/setup.py", "setup.py") + copyfile("packaging/connect/setup.cfg", "setup.cfg") + + # If you are changing the versions here, please also change ./python/pyspark/sql/pandas/utils.py + # For Arrow, you should also check ./pom.xml and ensure there are no breaking changes in the + # binary format protocol with the Java version, see ARROW_HOME/format/* for specifications. + # Also don't forget to update python/docs/source/getting_started/install.rst, and + # python/packaging/classic/setup.py + _minimum_pandas_version = "1.4.4" + _minimum_numpy_version = "1.21" + _minimum_pyarrow_version = "4.0.0" + _minimum_grpc_version = "1.59.3" + _minimum_googleapis_common_protos_version = "1.56.4" + + with open("README.md") as f: + long_description = f.read() + + connect_packages = [ + "pyspark", + "pyspark.cloudpickle", + "pyspark.mllib", + "pyspark.mllib.linalg", + "pyspark.mllib.stat", + "pyspark.ml", + "pyspark.ml.connect", + "pyspark.ml.linalg", + "pyspark.ml.param", + "pyspark.ml.torch", + "pyspark.ml.deepspeed", + "pyspark.sql", + "pyspark.sql.avro", + "pyspark.sql.connect", + "pyspark.sql.connect.avro", + "pyspark.sql.connect.client", + "pyspark.sql.connect.functions", + "pyspark.sql.connect.proto", + "pyspark.sql.connect.streaming", + "pyspark.sql.connect.streaming.worker", + "pyspark.sql.functions", + "pyspark.sql.pandas", + "pyspark.sql.protobuf", + "pyspark.sql.streaming", + "pyspark.sql.worker", + "pyspark.streaming", + "pyspark.pandas", + "pyspark.pandas.data_type_ops", + "pyspark.pandas.indexes", + "pyspark.pandas.missing", + "pyspark.pandas.plot", + "pyspark.pandas.spark", + "pyspark.pandas.typedef", + "pyspark.pandas.usage_logging", + "pyspark.testing", + "pyspark.resource", + "pyspark.errors", + "pyspark.errors.exceptions", + ] + + setup( + name="pyspark-connect", + version=VERSION, + description="Python Spark Connect client for Apache Spark", + long_description=long_description, + long_description_content_type="text/markdown", + author="Spark Developers", + author_email="dev@spark.apache.org", + url="https://github.com/apache/spark/tree/master/python", + packages=connect_packages, + license="http://www.apache.org/licenses/LICENSE-2.0", + # Don't forget to update python/docs/source/getting_started/install.rst + # if you're updating the versions or dependencies. 
+ install_requires=[ + "pandas>=%s" % _minimum_pandas_version, + "pyarrow>=%s" % _minimum_pyarrow_version, + "grpcio>=%s" % _minimum_grpc_version, + "grpcio-status>=%s" % _minimum_grpc_version, + "googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version, + "numpy>=%s" % _minimum_numpy_version, + ], + python_requires=">=3.8", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Typing :: Typed", + ], + ) +finally: + if in_spark: + os.remove("setup.py") + os.remove("setup.cfg") diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index a28144aedd70..032da1857a87 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -46,17 +46,30 @@ A inheritable thread to use in Spark when the pinned thread mode is on. """ +import sys from functools import wraps from typing import cast, Any, Callable, TypeVar, Union -from pyspark.conf import SparkConf -from pyspark.rdd import RDD, RDDBarrier -from pyspark.files import SparkFiles -from pyspark.status import StatusTracker, SparkJobInfo, SparkStageInfo +from pyspark.util import is_remote_only + +if not is_remote_only(): + from pyspark.core.conf import SparkConf + from pyspark.core.rdd import RDD, RDDBarrier + from pyspark.core.files import SparkFiles + from pyspark.core.status import StatusTracker, SparkJobInfo, SparkStageInfo + from pyspark.core.broadcast import Broadcast + from pyspark.core import conf, rdd, files, status, broadcast + + # for backward compatibility references. + sys.modules["pyspark.conf"] = conf + sys.modules["pyspark.rdd"] = rdd + sys.modules["pyspark.files"] = files + sys.modules["pyspark.status"] = status + sys.modules["pyspark.broadcast"] = broadcast + from pyspark.util import InheritableThread, inheritable_thread_target from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam -from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, CPickleSerializer from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo from pyspark.profiler import Profiler, BasicProfiler @@ -106,7 +119,12 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: # To avoid circular dependencies -from pyspark.context import SparkContext +if not is_remote_only(): + from pyspark.core.context import SparkContext + from pyspark.core import context + + # for backward compatibility references. 
+ sys.modules["pyspark.context"] = context # for back compatibility from pyspark.sql import SQLContext, HiveContext, Row # noqa: F401 diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index bf3d96b08515..d205720bd883 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -333,7 +333,7 @@ def _start_update_server(auth_token: str) -> AccumulatorServer: if __name__ == "__main__": import doctest - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext globs = globals().copy() # The small batch size here ensures that we see multiple batches, diff --git a/python/pyspark/core/__init__.py b/python/pyspark/core/__init__.py new file mode 100644 index 000000000000..cce3acad34a4 --- /dev/null +++ b/python/pyspark/core/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/broadcast.py b/python/pyspark/core/broadcast.py similarity index 97% rename from python/pyspark/broadcast.py rename to python/pyspark/core/broadcast.py index a5a68d779781..0b0d027b929e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/core/broadcast.py @@ -37,9 +37,8 @@ Union, ) -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ChunkedStream, pickle_protocol -from pyspark.util import print_exec +from pyspark.util import print_exec, local_connect_and_auth from pyspark.errors import PySparkRuntimeError if TYPE_CHECKING: @@ -56,7 +55,7 @@ def _from_id(bid: int) -> "Broadcast[Any]": - from pyspark.broadcast import _broadcastRegistry + from pyspark.core.broadcast import _broadcastRegistry if bid not in _broadcastRegistry: raise PySparkRuntimeError( @@ -367,13 +366,13 @@ def clear(self) -> None: def _test() -> None: import doctest from pyspark.sql import SparkSession - import pyspark.broadcast + import pyspark.core.broadcast - globs = pyspark.broadcast.__dict__.copy() + globs = pyspark.core.broadcast.__dict__.copy() spark = SparkSession.builder.master("local[4]").appName("broadcast tests").getOrCreate() globs["spark"] = spark - (failure_count, test_count) = doctest.testmod(pyspark.broadcast, globs=globs) + (failure_count, test_count) = doctest.testmod(pyspark.core.broadcast, globs=globs) spark.stop() if failure_count: sys.exit(-1) diff --git a/python/pyspark/conf.py b/python/pyspark/core/conf.py similarity index 96% rename from python/pyspark/conf.py rename to python/pyspark/core/conf.py index ba43a506375a..fe7879c3501b 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/core/conf.py @@ -60,11 +60,11 @@ class SparkConf: Examples -------- - >>> from pyspark.conf import SparkConf - >>> from pyspark.context import SparkContext + >>> from pyspark.core.conf import SparkConf + >>> from pyspark.core.context import 
SparkContext >>> conf = SparkConf() >>> conf.setMaster("local").setAppName("My app") - + >>> conf.get("spark.master") 'local' >>> conf.get("spark.app.name") @@ -79,13 +79,13 @@ class SparkConf: >>> conf = SparkConf(loadDefaults=False) >>> conf.setSparkHome("/path") - + >>> conf.get("spark.home") '/path' >>> conf.setExecutorEnv("VAR1", "value1") - + >>> conf.setExecutorEnv(pairs = [("VAR3", "value3"), ("VAR4", "value4")]) - + >>> conf.get("spark.executorEnv.VAR1") 'value1' >>> print(conf.toDebugString()) @@ -124,7 +124,7 @@ def __init__( if _jconf: self._jconf = _jconf else: - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext _jvm = _jvm or SparkContext._jvm diff --git a/python/pyspark/context.py b/python/pyspark/core/context.py similarity index 99% rename from python/pyspark/context.py rename to python/pyspark/core/context.py index bcc9fbf935ba..076119b959b1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/core/context.py @@ -48,10 +48,10 @@ from pyspark import accumulators from pyspark.accumulators import Accumulator -from pyspark.broadcast import Broadcast, BroadcastPickleRegistry -from pyspark.conf import SparkConf -from pyspark.files import SparkFiles -from pyspark.java_gateway import launch_gateway, local_connect_and_auth +from pyspark.core.broadcast import Broadcast, BroadcastPickleRegistry +from pyspark.core.conf import SparkConf +from pyspark.core.files import SparkFiles +from pyspark.java_gateway import launch_gateway from pyspark.serializers import ( CPickleSerializer, BatchedSerializer, @@ -64,10 +64,11 @@ ) from pyspark.storagelevel import StorageLevel from pyspark.resource.information import ResourceInformation -from pyspark.rdd import RDD, _load_from_socket +from pyspark.core.rdd import RDD +from pyspark.util import _load_from_socket, local_connect_and_auth from pyspark.taskcontext import TaskContext from pyspark.traceback_utils import CallSite, first_spark_call -from pyspark.status import StatusTracker +from pyspark.core.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler, UDFBasicProfiler, MemoryProfiler from pyspark.errors import PySparkRuntimeError from py4j.java_gateway import is_instance_of, JavaGateway, JavaObject, JVMView @@ -144,7 +145,7 @@ class SparkContext: Examples -------- - >>> from pyspark.context import SparkContext + >>> from pyspark.core.context import SparkContext >>> sc = SparkContext('local', 'test') >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): diff --git a/python/pyspark/files.py b/python/pyspark/core/files.py similarity index 98% rename from python/pyspark/files.py rename to python/pyspark/core/files.py index 92130389d975..83b98726aee7 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/core/files.py @@ -135,7 +135,7 @@ def getRootDirectory(cls) -> str: Examples -------- - >>> from pyspark.files import SparkFiles + >>> from pyspark.core.files import SparkFiles >>> SparkFiles.getRootDirectory() # doctest: +SKIP '.../spark-a904728e-08d3-400c-a872-cfd82fd6dcd2/userFiles-648cf6d6-bb2c-4f53-82bd-e658aba0c5de' """ diff --git a/python/pyspark/rdd.py b/python/pyspark/core/rdd.py similarity index 96% rename from python/pyspark/rdd.py rename to python/pyspark/core/rdd.py index ff5ded8101ff..9a256c4ae14a 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/core/rdd.py @@ -51,7 +51,6 @@ TYPE_CHECKING, ) -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( 
AutoBatchedSerializer, BatchedSerializer, @@ -62,8 +61,6 @@ CPickleSerializer, Serializer, pack_long, - read_int, - write_int, ) from pyspark.join import ( python_join, @@ -86,42 +83,28 @@ ExternalGroupBy, ) from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_stopiteration, _parse_memory +from pyspark.util import ( + fail_on_stopiteration, + _parse_memory, + _load_from_socket, + _local_iterator_from_socket, +) from pyspark.errors import PySparkRuntimeError +# for backward compatibility references. +from pyspark.util import PythonEvalType # noqa: F401 -if TYPE_CHECKING: - import socket - import io +if TYPE_CHECKING: from py4j.java_gateway import JavaObject - from py4j.java_collections import JavaArray - from pyspark._typing import NonUDFType from pyspark._typing import S, NumberOrArray - from pyspark.context import SparkContext - from pyspark.sql.pandas._typing import ( - PandasScalarUDFType, - PandasGroupedMapUDFType, - PandasGroupedAggUDFType, - PandasWindowAggUDFType, - PandasScalarIterUDFType, - PandasMapIterUDFType, - PandasCogroupedMapUDFType, - ArrowMapIterUDFType, - PandasGroupedMapUDFWithStateType, - ArrowGroupedMapUDFType, - ArrowCogroupedMapUDFType, - ) + from pyspark.core.context import SparkContext from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType from pyspark.sql._typing import ( AtomicValue, RowLike, - SQLArrowBatchedUDFType, - SQLArrowTableUDFType, - SQLBatchedUDFType, - SQLTableUDFType, ) T = TypeVar("T") @@ -137,36 +120,6 @@ __all__ = ["RDD"] -class PythonEvalType: - """ - Evaluation type of python rdd. - - These values are internal to PySpark. - - These values should match values in org.apache.spark.api.python.PythonEvalType. - """ - - NON_UDF: "NonUDFType" = 0 - - SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100 - SQL_ARROW_BATCHED_UDF: "SQLArrowBatchedUDFType" = 101 - - SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200 - SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201 - SQL_GROUPED_AGG_PANDAS_UDF: "PandasGroupedAggUDFType" = 202 - SQL_WINDOW_AGG_PANDAS_UDF: "PandasWindowAggUDFType" = 203 - SQL_SCALAR_PANDAS_ITER_UDF: "PandasScalarIterUDFType" = 204 - SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 - SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 - SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 - SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 - SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209 - SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210 - - SQL_TABLE_UDF: "SQLTableUDFType" = 300 - SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301 - - def portable_hash(x: Hashable) -> int: """ This function returns consistent hash code for builtin types, especially @@ -226,100 +179,6 @@ def __new__(cls, mean: float, confidence: float, low: float, high: float) -> "Bo return obj -def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair": - """ - Create a local socket that can be used to load deserialized data from the JVM - - Parameters - ---------- - sock_info : tuple - Tuple containing port number and authentication secret for a local socket. 
- - Returns - ------- - sockfile file descriptor of the local socket - """ - sockfile: "io.BufferedRWPair" - sock: "socket.socket" - port: int = sock_info[0] - auth_secret: str = sock_info[1] - sockfile, sock = local_connect_and_auth(port, auth_secret) - # The RDD materialization time is unpredictable, if we set a timeout for socket reading - # operation, it will very possibly fail. See SPARK-18281. - sock.settimeout(None) - return sockfile - - -def _load_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]: - """ - Connect to a local socket described by sock_info and use the given serializer to yield data - - Parameters - ---------- - sock_info : tuple - Tuple containing port number and authentication secret for a local socket. - serializer : class:`Serializer` - The PySpark serializer to use - - Returns - ------- - result of meth:`Serializer.load_stream`, - usually a generator that yields deserialized data - """ - sockfile = _create_local_socket(sock_info) - # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sockfile) - - -def _local_iterator_from_socket(sock_info: "JavaArray", serializer: Serializer) -> Iterator[Any]: - class PyLocalIterable: - """Create a synchronous local iterable over a socket""" - - def __init__(self, _sock_info: "JavaArray", _serializer: Serializer): - port: int - auth_secret: str - jsocket_auth_server: "JavaObject" - port, auth_secret, self.jsocket_auth_server = _sock_info - self._sockfile = _create_local_socket((port, auth_secret)) - self._serializer = _serializer - self._read_iter: Iterator[Any] = iter([]) # Initialize as empty iterator - self._read_status = 1 - - def __iter__(self) -> Iterator[Any]: - while self._read_status == 1: - # Request next partition data from Java - write_int(1, self._sockfile) - self._sockfile.flush() - - # If response is 1 then there is a partition to read, if 0 then fully consumed - self._read_status = read_int(self._sockfile) - if self._read_status == 1: - # Load the partition data as a stream and read each item - self._read_iter = self._serializer.load_stream(self._sockfile) - for item in self._read_iter: - yield item - - # An error occurred, join serving thread and raise any exceptions from the JVM - elif self._read_status == -1: - self.jsocket_auth_server.getResult() - - def __del__(self) -> None: - # If local iterator is not fully consumed, - if self._read_status == 1: - try: - # Finish consuming partition data stream - for _ in self._read_iter: - pass - # Tell Java to stop sending data and close connection - write_int(0, self._sockfile) - self._sockfile.flush() - except Exception: - # Ignore any errors, socket is automatically closed when garbage-collected - pass - - return iter(PyLocalIterable(sock_info, serializer)) - - class Partitioner: def __init__(self, numPartitions: int, partitionFunc: Callable[[Any], int]): self.numPartitions = numPartitions @@ -5343,7 +5202,7 @@ def mapPartitions( ... >>> barrier = rdd.barrier() >>> barrier - + >>> barrier.mapPartitions(f).collect() [3, 7] """ @@ -5396,7 +5255,7 @@ def mapPartitionsWithIndex( ... 
>>> barrier = rdd.barrier() >>> barrier - + >>> barrier.mapPartitionsWithIndex(f).sum() 6 """ @@ -5509,7 +5368,7 @@ def _is_barrier(self) -> bool: def _test() -> None: import doctest import tempfile - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext tmp_dir = tempfile.TemporaryDirectory() globs = globals().copy() diff --git a/python/pyspark/status.py b/python/pyspark/core/status.py similarity index 100% rename from python/pyspark/status.py rename to python/pyspark/core/status.py diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py index e78a7c7bce8a..e5ec257fb32e 100644 --- a/python/pyspark/errors/exceptions/captured.py +++ b/python/pyspark/errors/exceptions/captured.py @@ -15,13 +15,8 @@ # limitations under the License. # from contextlib import contextmanager -from typing import Any, Callable, Dict, Iterator, Optional, cast, List +from typing import Any, Callable, Dict, Iterator, Optional, cast, List, TYPE_CHECKING -import py4j -from py4j.protocol import Py4JJavaError -from py4j.java_gateway import is_instance_of, JavaObject - -from pyspark import SparkContext from pyspark.errors.exceptions.base import ( AnalysisException as BaseAnalysisException, IllegalArgumentException as BaseIllegalArgumentException, @@ -43,15 +38,22 @@ QueryContextType, ) +if TYPE_CHECKING: + from py4j.protocol import Py4JJavaError + from py4j.java_gateway import JavaObject + class CapturedException(PySparkException): def __init__( self, desc: Optional[str] = None, stackTrace: Optional[str] = None, - cause: Optional[Py4JJavaError] = None, - origin: Optional[Py4JJavaError] = None, + cause: Optional["Py4JJavaError"] = None, + origin: Optional["Py4JJavaError"] = None, ): + from pyspark import SparkContext + from py4j.protocol import Py4JJavaError + # desc & stackTrace vs origin are mutually exclusive. # cause is optional. 
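Editor's note: the errors/exceptions/captured.py hunks above show a pattern this change applies repeatedly (also in java_gateway.py, sql/column.py, sql/conf.py, resource/profile.py and resource/requests.py below): py4j is no longer imported at module import time; annotations move under `TYPE_CHECKING` and become string-quoted, and the real imports happen inside the methods that actually talk to the JVM. A minimal, hypothetical illustration of that pattern follows (the class name is made up and is not Spark code):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; importing this module never pulls in py4j.
    from py4j.protocol import Py4JJavaError


class JvmErrorHolder:
    """Hypothetical example of deferring the py4j import to call time."""

    def __init__(self, cause: Optional["Py4JJavaError"] = None) -> None:
        self._cause = cause

    def message(self) -> str:
        # Imported here, so an environment without py4j (e.g. a Spark Connect
        # client) can still import the module as long as this path never runs.
        from py4j.protocol import Py4JJavaError

        if isinstance(self._cause, Py4JJavaError):
            return str(self._cause.java_exception)
        return ""
```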
assert (origin is not None and desc is None and stackTrace is None) or ( @@ -73,6 +75,8 @@ def __init__( self._origin = origin def __str__(self) -> str: + from pyspark import SparkContext + assert SparkContext._jvm is not None jvm = SparkContext._jvm @@ -91,6 +95,9 @@ def __str__(self) -> str: return str(desc) def getErrorClass(self) -> Optional[str]: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway @@ -102,6 +109,9 @@ def getErrorClass(self) -> Optional[str]: return None def getMessageParameters(self) -> Optional[Dict[str, str]]: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway @@ -113,6 +123,9 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]: return None def getSqlState(self) -> Optional[str]: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway if self._origin is not None and is_instance_of( @@ -123,6 +136,9 @@ def getSqlState(self) -> Optional[str]: return None def getMessage(self) -> str: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway @@ -141,6 +157,9 @@ def getMessage(self) -> str: return "" def getQueryContext(self) -> List[BaseQueryContext]: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway @@ -152,7 +171,10 @@ def getQueryContext(self) -> List[BaseQueryContext]: return [] -def convert_exception(e: Py4JJavaError) -> CapturedException: +def convert_exception(e: "Py4JJavaError") -> CapturedException: + from pyspark import SparkContext + from py4j.java_gateway import is_instance_of + assert e is not None assert SparkContext._jvm is not None assert SparkContext._gateway is not None @@ -189,7 +211,7 @@ def convert_exception(e: Py4JJavaError) -> CapturedException: elif is_instance_of(gw, e, "org.apache.spark.SparkNoSuchElementException"): return SparkNoSuchElementException(origin=e) - c: Py4JJavaError = e.getCause() + c: "Py4JJavaError" = e.getCause() stacktrace: str = jvm.org.apache.spark.util.Utils.exceptionString(e) if c is not None and ( is_instance_of(gw, c, "org.apache.spark.api.python.PythonException") @@ -211,6 +233,8 @@ def convert_exception(e: Py4JJavaError) -> CapturedException: def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]: def deco(*a: Any, **kw: Any) -> Any: + from py4j.protocol import Py4JJavaError + try: return f(*a, **kw) except Py4JJavaError as e: @@ -227,13 +251,17 @@ def deco(*a: Any, **kw: Any) -> Any: @contextmanager def unwrap_spark_exception() -> Iterator[Any]: + from pyspark import SparkContext + from py4j.protocol import Py4JJavaError + from py4j.java_gateway import is_instance_of + assert SparkContext._gateway is not None gw = SparkContext._gateway try: yield except Py4JJavaError as e: - je: Py4JJavaError = e.java_exception + je: "Py4JJavaError" = e.java_exception if je is not None and is_instance_of(gw, je, "org.apache.spark.SparkException"): converted = convert_exception(je.getCause()) if not isinstance(converted, UnknownException): @@ -252,6 +280,8 @@ def install_exception_handler() -> None: It's idempotent, could be called multiple times. 
""" + import py4j + original = py4j.protocol.get_return_value # The original `get_return_value` is not patched, it's idempotent. patched = capture_sql_exception(original) @@ -350,7 +380,7 @@ class UnknownException(CapturedException, BaseUnknownException): class QueryContext(BaseQueryContext): - def __init__(self, q: JavaObject): + def __init__(self, q: "JavaObject"): self._q = q def contextType(self) -> QueryContextType: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 39a90a0afbad..18b6536c7403 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -20,7 +20,6 @@ import signal import shlex import shutil -import socket import platform import tempfile import time @@ -28,11 +27,14 @@ from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters from py4j.clientserver import ClientServer, JavaParameters, PythonParameters +from pyspark.serializers import read_int, UTF8Deserializer from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int, write_with_length, UTF8Deserializer from pyspark.errors import PySparkRuntimeError +# for backward compatibility references. +from pyspark.util import local_connect_and_auth # noqa: F401 + def launch_gateway(conf=None, popen_kwargs=None): """ @@ -167,65 +169,6 @@ def killChild(): return gateway -def _do_server_auth(conn, auth_secret): - """ - Performs the authentication protocol defined by the SocketAuthHelper class on the given - file-like object 'conn'. - """ - write_with_length(auth_secret.encode("utf-8"), conn) - conn.flush() - reply = UTF8Deserializer().loads(conn) - if reply != "ok": - conn.close() - raise PySparkRuntimeError( - error_class="UNEXPECTED_RESPONSE_FROM_SERVER", - message_parameters={}, - ) - - -def local_connect_and_auth(port, auth_secret): - """ - Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. - Handles IPV4 & IPV6, does some error handling. - - Parameters - ---------- - port : str or int or None - auth_secret : str - - Returns - ------- - tuple - with (sockfile, sock) - """ - sock = None - errors = [] - # Support for both IPv4 and IPv6. - addr = "127.0.0.1" - if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true": - addr = "::1" - for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, socket.SOCK_STREAM): - af, socktype, proto, _, sa = res - try: - sock = socket.socket(af, socktype, proto) - sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))) - sock.connect(sa) - sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))) - _do_server_auth(sockfile, auth_secret) - return (sockfile, sock) - except socket.error as e: - emsg = str(e) - errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg)) - sock.close() - sock = None - raise PySparkRuntimeError( - error_class="CANNOT_OPEN_SOCKET", - message_parameters={ - "errors": str(errors), - }, - ) - - def ensure_callback_server_started(gw): """ Start callback server if not already started. 
The callback server is needed if the Java diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index 91593515a2de..d9d7c9b9c21c 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -22,7 +22,7 @@ from py4j.java_gateway import JavaObject from py4j.java_collections import JavaArray, JavaList -import pyspark.context +import pyspark.core.context from pyspark import RDD, SparkContext from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession @@ -122,7 +122,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt def callJavaFunc( - sc: pyspark.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any + sc: pyspark.core.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any ) -> "JavaObjectOrPickleDump": """Call Java Function""" java_args = [_py2java(sc, a) for a in args] diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py index 6ac74c22b380..62a71c5a96af 100644 --- a/python/pyspark/ml/torch/distributor.py +++ b/python/pyspark/ml/torch/distributor.py @@ -847,7 +847,7 @@ def _setup_spark_partition_data( partition_data_iterator: Iterator[Any], input_schema_json: Dict[str, Any] ) -> Iterator[Any]: from pyspark.sql.pandas.serializers import ArrowStreamSerializer - from pyspark.files import SparkFiles + from pyspark.core.files import SparkFiles import json if input_schema_json is None: diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 71f42954decb..ff5d69b084e2 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -25,7 +25,7 @@ from numpy import array, random, tile from pyspark import SparkContext, since -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector # noqa: F401 from pyspark.mllib.stat.distribution import MultivariateGaussian diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index c5e1a7e8580c..bfab55b8552c 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -25,7 +25,7 @@ from py4j.java_gateway import JavaObject from py4j.java_collections import JavaArray, JavaList -import pyspark.context +import pyspark.core.context from pyspark import RDD, SparkContext from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession @@ -124,7 +124,7 @@ def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "byt def callJavaFunc( - sc: pyspark.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any + sc: pyspark.core.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any ) -> Any: """Call Java Function""" java_args = [_py2java(sc, a) for a in args] diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index df756f848429..dfcee167bea5 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -19,7 +19,7 @@ import sys from pyspark import since -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.mllib.linalg import Matrix from pyspark.sql import SQLContext diff --git 
a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 1e1975306116..24884f485337 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -25,11 +25,11 @@ from py4j.protocol import Py4JJavaError from pyspark import since -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Vectors, _convert_to_vector from pyspark.mllib.util import JavaLoader, JavaSaveable -from pyspark.context import SparkContext +from pyspark.core.context import SparkContext from pyspark.mllib.linalg import Vector from pyspark.mllib.regression import LabeledPoint from py4j.java_collections import JavaMap diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 4b26ca642296..19da87da3c27 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -22,7 +22,7 @@ from pyspark import since, SparkContext from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD __all__ = ["FPGrowth", "FPGrowthModel", "PrefixSpan", "PrefixSpanModel"] diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 1342148a61a5..80bbd717071d 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -26,8 +26,8 @@ import numpy as np from pyspark.mllib.common import callMLlibFunc -from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.core.context import SparkContext +from pyspark.core.rdd import RDD from pyspark.mllib.linalg import Vector diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 7ff8fddf88d3..fa3d1a73f35a 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -20,7 +20,7 @@ from typing import Any, List, NamedTuple, Optional, Tuple, Type, Union from pyspark import SparkContext, since -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc from pyspark.mllib.util import JavaLoader, JavaSaveable from pyspark.sql import DataFrame diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cac3294ade62..f1003327912d 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -37,8 +37,8 @@ from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.util import Saveable, Loader -from pyspark.rdd import RDD -from pyspark.context import SparkContext +from pyspark.core.rdd import RDD +from pyspark.core.context import SparkContext from pyspark.mllib.linalg import Vector if TYPE_CHECKING: diff --git a/python/pyspark/mllib/stat/KernelDensity.py b/python/pyspark/mllib/stat/KernelDensity.py index febf4fd53fd2..bb03d0ef1357 100644 --- a/python/pyspark/mllib/stat/KernelDensity.py +++ b/python/pyspark/mllib/stat/KernelDensity.py @@ -21,7 +21,7 @@ from numpy import ndarray from pyspark.mllib.common import callMLlibFunc -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD class KernelDensity: diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index a784e0e31733..c638fb819506 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py 
@@ -21,7 +21,7 @@ from numpy import ndarray from py4j.java_gateway import JavaObject -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, Vector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 8a5c25d96a74..b24bced3ced6 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -24,7 +24,7 @@ from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.util import JavaLoader, JavaSaveable from typing import Dict, Optional, Tuple, Union, overload, TYPE_CHECKING -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD if TYPE_CHECKING: from pyspark.mllib._typing import VectorLike diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index c5fb7f39c526..5572d9ca8555 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -25,9 +25,9 @@ from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.sql import DataFrame from typing import Generic, Iterable, List, Optional, Tuple, Type, TypeVar, cast, TYPE_CHECKING -from pyspark.context import SparkContext +from pyspark.core.context import SparkContext from pyspark.mllib.linalg import Vector -from pyspark.rdd import RDD +from pyspark.core.rdd import RDD from pyspark.sql.dataframe import DataFrame T = TypeVar("T") diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 37605a4a9534..736873e963c9 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -48,7 +48,7 @@ from pyspark.errors import PySparkRuntimeError if TYPE_CHECKING: - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext MemoryTuple = Tuple[float, float, int] LineProfile = Tuple[int, Optional[MemoryTuple]] diff --git a/python/pyspark/resource/profile.py b/python/pyspark/resource/profile.py index a982f608c196..a22afdf16c8b 100644 --- a/python/pyspark/resource/profile.py +++ b/python/pyspark/resource/profile.py @@ -15,9 +15,7 @@ # limitations under the License. # from threading import RLock -from typing import overload, Dict, Union, Optional - -from py4j.java_gateway import JavaObject +from typing import overload, Dict, Union, Optional, TYPE_CHECKING from pyspark.resource.requests import ( TaskResourceRequest, @@ -26,6 +24,9 @@ ExecutorResourceRequest, ) +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + class ResourceProfile: @@ -84,7 +85,7 @@ class ResourceProfile: """ @overload - def __init__(self, _java_resource_profile: JavaObject): + def __init__(self, _java_resource_profile: "JavaObject"): ... 
@overload @@ -98,7 +99,7 @@ def __init__( def __init__( self, - _java_resource_profile: Optional[JavaObject] = None, + _java_resource_profile: Optional["JavaObject"] = None, _exec_req: Optional[Dict[str, ExecutorResourceRequest]] = None, _task_req: Optional[Dict[str, TaskResourceRequest]] = None, ): @@ -200,7 +201,7 @@ class ResourceProfileBuilder: """ def __init__(self) -> None: - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext # TODO: ignore[attr-defined] will be removed, once SparkContext is inlined _jvm = SparkContext._jvm diff --git a/python/pyspark/resource/requests.py b/python/pyspark/resource/requests.py index d3a43d3a06f7..746fca984839 100644 --- a/python/pyspark/resource/requests.py +++ b/python/pyspark/resource/requests.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import overload, Optional, Dict - -from py4j.java_gateway import JavaObject, JVMView +from typing import overload, Optional, Dict, TYPE_CHECKING from pyspark.util import _parse_memory +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject, JVMView + class ExecutorResourceRequest: """ @@ -147,7 +148,7 @@ class ExecutorResourceRequests: _OFFHEAP_MEM = "offHeap" @overload - def __init__(self, _jvm: JVMView): + def __init__(self, _jvm: "JVMView"): ... @overload @@ -160,7 +161,7 @@ def __init__( def __init__( self, - _jvm: Optional[JVMView] = None, + _jvm: Optional["JVMView"] = None, _requests: Optional[Dict[str, ExecutorResourceRequest]] = None, ): from pyspark import SparkContext @@ -445,7 +446,7 @@ class TaskResourceRequests: _CPUS = "cpus" @overload - def __init__(self, _jvm: JVMView): + def __init__(self, _jvm: "JVMView"): ... @overload @@ -458,7 +459,7 @@ def __init__( def __init__( self, - _jvm: Optional[JVMView] = None, + _jvm: Optional["JVMView"] = None, _requests: Optional[Dict[str, TaskResourceRequest]] = None, ): from pyspark import SparkContext @@ -468,7 +469,7 @@ def __init__( if _jvm is not None and not is_remote(): self._java_task_resource_requests: Optional[ - JavaObject + "JavaObject" ] = _jvm.org.apache.spark.resource.TaskResourceRequests() if _requests is not None: for k, v in _requests.items(): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index d269d55653cf..821e142304aa 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -29,7 +29,7 @@ -------- The serializer is chosen when creating :class:`SparkContext`: ->>> from pyspark.context import SparkContext +>>> from pyspark.core.context import SparkContext >>> from pyspark.serializers import MarshalSerializer >>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) @@ -67,7 +67,6 @@ pickle_protocol = pickle.HIGHEST_PROTOCOL from pyspark import cloudpickle -from pyspark.util import print_exec __all__ = [ @@ -455,6 +454,8 @@ def loads(self, obj, encoding="bytes"): class CloudPickleSerializer(FramedSerializer): def dumps(self, obj): + from pyspark.util import print_exec + try: return cloudpickle.dumps(obj, pickle_protocol) except pickle.PickleError: diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 7e2093c1d31d..f705f0edd8fe 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -29,7 +29,7 @@ import sys import pyspark -from pyspark.context import SparkContext +from pyspark.core.context import SparkContext from pyspark.sql import SparkSession from 
pyspark.sql.context import SQLContext from pyspark.sql.utils import is_remote diff --git a/python/pyspark/sql/avro/functions.py b/python/pyspark/sql/avro/functions.py index e49953e8953b..5cebfa384045 100644 --- a/python/pyspark/sql/avro/functions.py +++ b/python/pyspark/sql/avro/functions.py @@ -22,8 +22,6 @@ from typing import Dict, Optional, TYPE_CHECKING, cast -from py4j.java_gateway import JVMView - from pyspark.sql.column import Column, _to_java_column from pyspark.sql.utils import get_active_spark_context, try_remote_avro_functions from pyspark.util import _print_missing_jar @@ -79,6 +77,7 @@ def from_avro( >>> avroDf.select(from_avro(avroDf.avro, jsonFormatSchema).alias("value")).collect() [Row(value=Row(avro=Row(age=2, name='Alice')))] """ + from py4j.java_gateway import JVMView sc = get_active_spark_context() try: @@ -128,6 +127,7 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column: >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect() [Row(suite=bytearray(b'\\x02\\x00'))] """ + from py4j.java_gateway import JVMView sc = get_active_spark_context() try: diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index bf6192f8c58d..31c1013742a0 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -31,14 +31,13 @@ Union, ) -from py4j.java_gateway import JavaObject, JVMView - -from pyspark.context import SparkContext from pyspark.errors import PySparkAttributeError, PySparkTypeError, PySparkValueError from pyspark.sql.types import DataType from pyspark.sql.utils import get_active_spark_context if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.context import SparkContext from pyspark.sql._typing import ColumnOrName, LiteralType, DecimalLiteral, DateTimeLiteral from pyspark.sql.window import WindowSpec @@ -46,16 +45,20 @@ def _create_column_from_literal(literal: Union["LiteralType", "DecimalLiteral"]) -> "Column": + from py4j.java_gateway import JVMView + sc = get_active_spark_context() return cast(JVMView, sc._jvm).functions.lit(literal) def _create_column_from_name(name: str) -> "Column": + from py4j.java_gateway import JVMView + sc = get_active_spark_context() return cast(JVMView, sc._jvm).functions.col(name) -def _to_java_column(col: "ColumnOrName") -> JavaObject: +def _to_java_column(col: "ColumnOrName") -> "JavaObject": if isinstance(col, Column): jcol = col._jc elif isinstance(col, str): @@ -68,29 +71,29 @@ def _to_java_column(col: "ColumnOrName") -> JavaObject: return jcol -def _to_java_expr(col: "ColumnOrName") -> JavaObject: +def _to_java_expr(col: "ColumnOrName") -> "JavaObject": return _to_java_column(col).expr() @overload -def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject: +def _to_seq(sc: "SparkContext", cols: Iterable["JavaObject"]) -> "JavaObject": ... @overload def _to_seq( - sc: SparkContext, + sc: "SparkContext", cols: Iterable["ColumnOrName"], - converter: Optional[Callable[["ColumnOrName"], JavaObject]], -) -> JavaObject: + converter: Optional[Callable[["ColumnOrName"], "JavaObject"]], +) -> "JavaObject": ... 
def _to_seq( - sc: SparkContext, - cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]], - converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None, -) -> JavaObject: + sc: "SparkContext", + cols: Union[Iterable["ColumnOrName"], Iterable["JavaObject"]], + converter: Optional[Callable[["ColumnOrName"], "JavaObject"]] = None, +) -> "JavaObject": """ Convert a list of Columns (or names) into a JVM Seq of Column. @@ -104,10 +107,10 @@ def _to_seq( def _to_list( - sc: SparkContext, + sc: "SparkContext", cols: List["ColumnOrName"], - converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None, -) -> JavaObject: + converter: Optional[Callable[["ColumnOrName"], "JavaObject"]] = None, +) -> "JavaObject": """ Convert a list of Columns (or names) into a JVM (Scala) List of Columns. @@ -136,6 +139,8 @@ def _(self: "Column") -> "Column": def _func_op(name: str, doc: str = "") -> Callable[["Column"], "Column"]: def _(self: "Column") -> "Column": + from py4j.java_gateway import JVMView + sc = get_active_spark_context() jc = getattr(cast(JVMView, sc._jvm).functions, name)(self._jc) return Column(jc) @@ -150,6 +155,8 @@ def _bin_func_op( doc: str = "binary function", ) -> Callable[["Column", Union["Column", "LiteralType", "DecimalLiteral"]], "Column"]: def _(self: "Column", other: Union["Column", "LiteralType", "DecimalLiteral"]) -> "Column": + from py4j.java_gateway import JVMView + sc = get_active_spark_context() fn = getattr(cast(JVMView, sc._jvm).functions, name) jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other) @@ -226,7 +233,7 @@ class Column: Column<...> """ - def __init__(self, jc: JavaObject) -> None: + def __init__(self, jc: "JavaObject") -> None: self._jc = jc # arithmetic operators diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py index dd43991b0706..b718f779a179 100644 --- a/python/pyspark/sql/conf.py +++ b/python/pyspark/sql/conf.py @@ -16,14 +16,15 @@ # import sys -from typing import Any, Dict, Optional, Union - -from py4j.java_gateway import JavaObject +from typing import Any, Dict, Optional, Union, TYPE_CHECKING from pyspark import _NoValue from pyspark._globals import _NoValueType from pyspark.errors import PySparkTypeError +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + class RuntimeConfig: """User-facing configuration API, accessible through `SparkSession.conf`. @@ -34,7 +35,7 @@ class RuntimeConfig: Supports Spark Connect. 
""" - def __init__(self, jconf: JavaObject) -> None: + def __init__(self, jconf: "JavaObject") -> None: """Create a new RuntimeConfig that wraps the underlying JVM object.""" self._jconf = jconf diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index c8cf12f40708..0cd6758e783e 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -90,7 +90,7 @@ from pyspark.sql.connect.utils import get_python_ver from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema from pyspark.sql.types import DataType, StructType, TimestampType, _has_type -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.storagelevel import StorageLevel from pyspark.errors import PySparkValueError, PySparkAssertionError, PySparkNotImplementedError diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 820a15429ecf..576c196dbd2b 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -64,7 +64,7 @@ PySparkNotImplementedError, PySparkRuntimeError, ) -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.storagelevel import StorageLevel import pyspark.sql.connect.plan as plan from pyspark.sql.connect.conversion import ArrowTableToRowsConversion diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py index db470011527b..088bb000a344 100644 --- a/python/pyspark/sql/connect/group.py +++ b/python/pyspark/sql/connect/group.py @@ -31,7 +31,7 @@ cast, ) -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.group import GroupedData as PySparkGroupedData from pyspark.sql.pandas.group_ops import PandasCogroupedOps as PySparkPandasCogroupedOps from pyspark.sql.types import NumericType diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 8b7e403667cf..13cad30bbff9 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -48,7 +48,6 @@ ) import urllib -from pyspark import SparkContext, SparkConf, __version__ from pyspark.loose_version import LooseVersion from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder from pyspark.sql.connect.conf import RuntimeConf @@ -889,6 +888,8 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None: 3. Starts a regular Spark session that automatically starts a Spark Connect server with JVM via ``spark.plugins`` feature. 
""" + from pyspark import SparkContext, SparkConf, __version__ + session = PySparkSession._instantiatedSession if session is None or session._sc._jsc is None: # Configurations to be overwritten diff --git a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py index 022e768c43b6..c4cf52b9996d 100644 --- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py @@ -21,7 +21,7 @@ """ import os -from pyspark.java_gateway import local_connect_and_auth +from pyspark.util import local_connect_and_auth from pyspark.serializers import ( write_int, read_long, diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py index bb6bcd5d9659..69e0d8a46248 100644 --- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -22,7 +22,7 @@ import os import json -from pyspark.java_gateway import local_connect_and_auth +from pyspark.util import local_connect_and_auth from pyspark.serializers import ( read_int, write_int, diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py index 1c42f4d74b7a..f3aa719b2bb6 100644 --- a/python/pyspark/sql/connect/udf.py +++ b/python/pyspark/sql/connect/udf.py @@ -27,7 +27,7 @@ from inspect import getfullargspec from typing import cast, Callable, Any, List, TYPE_CHECKING, Optional, Union -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.connect.expressions import ( ColumnReference, CommonInlineUserDefinedFunction, diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index f137864e026e..4ee39dc89b8e 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -24,7 +24,7 @@ import warnings from typing import List, Type, TYPE_CHECKING, Optional, Union -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.connect.column import Column from pyspark.sql.connect.expressions import ColumnReference, Expression, NamedArgumentExpression from pyspark.sql.connect.plan import ( diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 7ef7b320eeb4..cbb0299e2195 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -32,8 +32,6 @@ cast, ) -from py4j.java_gateway import JavaObject - from pyspark import _NoValue from pyspark._globals import _NoValueType from pyspark.sql.session import _monkey_patch_RDD, SparkSession @@ -43,12 +41,13 @@ from pyspark.sql.udf import UDFRegistration # noqa: F401 from pyspark.sql.udtf import UDTFRegistration from pyspark.errors.exceptions.captured import install_exception_handler -from pyspark.context import SparkContext -from pyspark.rdd import RDD from pyspark.sql.types import AtomicType, DataType, StructType from pyspark.sql.streaming import StreamingQueryManager if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.rdd import RDD + from pyspark.core.context import SparkContext from pyspark.sql._typing import ( AtomicValue, RowLike, @@ -104,9 +103,9 @@ class SQLContext: def __init__( self, - sparkContext: SparkContext, + sparkContext: "SparkContext", sparkSession: Optional[SparkSession] = None, - jsqlContext: Optional[JavaObject] = None, + jsqlContext: Optional["JavaObject"] = 
None, ): if sparkSession is None: warnings.warn( @@ -132,7 +131,7 @@ def __init__( SQLContext._instantiatedContext = self @property - def _ssql_ctx(self) -> JavaObject: + def _ssql_ctx(self) -> "JavaObject": """Accessor for the JVM Spark SQL context. Subclasses can override this property to provide their own @@ -141,7 +140,7 @@ def _ssql_ctx(self) -> JavaObject: return self._jsqlContext @classmethod - def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext": + def getOrCreate(cls: Type["SQLContext"], sc: "SparkContext") -> "SQLContext": """ Get the existing SQLContext or create a new one with given SparkContext. @@ -162,7 +161,7 @@ def getOrCreate(cls: Type["SQLContext"], sc: SparkContext) -> "SQLContext": @classmethod def _get_or_create( - cls: Type["SQLContext"], sc: SparkContext, **static_conf: Any + cls: Type["SQLContext"], sc: "SparkContext", **static_conf: Any ) -> "SQLContext": if ( cls._instantiatedContext is None @@ -359,7 +358,7 @@ def createDataFrame( def createDataFrame( # type: ignore[misc] self, - data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike"], + data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, @@ -714,9 +713,9 @@ class HiveContext(SQLContext): def __init__( self, - sparkContext: SparkContext, + sparkContext: "SparkContext", sparkSession: Optional[SparkSession] = None, - jhiveContext: Optional[JavaObject] = None, + jhiveContext: Optional["JavaObject"] = None, ): warnings.warn( "HiveContext is deprecated in Spark 2.0.0. Please use " @@ -734,12 +733,12 @@ def __init__( @classmethod def _get_or_create( - cls: Type["SQLContext"], sc: SparkContext, **static_conf: Any + cls: Type["SQLContext"], sc: "SparkContext", **static_conf: Any ) -> "SQLContext": return SQLContext._get_or_create(sc, **HiveContext._static_conf) @classmethod - def _createForTesting(cls, sparkContext: SparkContext) -> "HiveContext": + def _createForTesting(cls, sparkContext: "SparkContext") -> "HiveContext": """(Internal use only) Create a new HiveContext for testing. All test code that touches HiveContext *must* go through this method. 
Otherwise, @@ -765,7 +764,7 @@ def _test() -> None: import os import doctest import tempfile - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext from pyspark.sql import Row, SQLContext import pyspark.sql.context diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d04c35dac5e9..3266c4135a76 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -38,19 +38,16 @@ TYPE_CHECKING, ) -from py4j.java_gateway import JavaObject, JVMView - from pyspark import _NoValue from pyspark._globals import _NoValueType -from pyspark.context import SparkContext from pyspark.errors import ( PySparkTypeError, PySparkValueError, PySparkIndexError, PySparkAttributeError, ) -from pyspark.rdd import ( - RDD, +from pyspark.util import ( + is_remote_only, _load_from_socket, _local_iterator_from_socket, ) @@ -70,6 +67,9 @@ from pyspark.sql.pandas.map_ops import PandasMapOpsMixin if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.rdd import RDD + from pyspark.core.context import SparkContext from pyspark._typing import PrimitiveType from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame from pyspark.sql._typing import ( @@ -141,7 +141,7 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin): def __init__( self, - jdf: JavaObject, + jdf: "JavaObject", sql_ctx: Union["SQLContext", "SparkSession"], ): from pyspark.sql.context import SQLContext @@ -161,12 +161,12 @@ def __init__( session = sql_ctx self._session: "SparkSession" = session - self._sc: SparkContext = sql_ctx._sc - self._jdf: JavaObject = jdf + self._sc: "SparkContext" = sql_ctx._sc + self._jdf: "JavaObject" = jdf self.is_cached = False # initialized lazily self._schema: Optional[StructType] = None - self._lazy_rdd: Optional[RDD[Row]] = None + self._lazy_rdd: Optional["RDD[Row]"] = None # Check whether _repr_html is supported or not, we use it to avoid calling _jdf twice # by __repr__ and _repr_html_ while eager evaluation opens. self._support_repr_html = False @@ -204,28 +204,32 @@ def sparkSession(self) -> "SparkSession": """ return self._session - @property - def rdd(self) -> "RDD[Row]": - """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. + if not is_remote_only(): - .. versionadded:: 1.3.0 + @property + def rdd(self) -> "RDD[Row]": + """Returns the content as an :class:`pyspark.RDD` of :class:`Row`. - Returns - ------- - :class:`RDD` + .. versionadded:: 1.3.0 - Examples - -------- - >>> df = spark.range(1) - >>> type(df.rdd) - - """ - if self._lazy_rdd is None: - jrdd = self._jdf.javaToPython() - self._lazy_rdd = RDD( - jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()) - ) - return self._lazy_rdd + Returns + ------- + :class:`RDD` + + Examples + -------- + >>> df = spark.range(1) + >>> type(df.rdd) + + """ + from pyspark.core.rdd import RDD + + if self._lazy_rdd is None: + jrdd = self._jdf.javaToPython() + self._lazy_rdd = RDD( + jrdd, self.sparkSession._sc, BatchedSerializer(CPickleSerializer()) + ) + return self._lazy_rdd @property def na(self) -> "DataFrameNaFunctions": @@ -281,30 +285,34 @@ def stat(self) -> "DataFrameStatFunctions": """ return DataFrameStatFunctions(self) - def toJSON(self, use_unicode: bool = True) -> RDD[str]: - """Converts a :class:`DataFrame` into a :class:`RDD` of string. + if not is_remote_only(): - Each row is turned into a JSON document as one element in the returned RDD. 
+ def toJSON(self, use_unicode: bool = True) -> "RDD[str]": + """Converts a :class:`DataFrame` into a :class:`RDD` of string. - .. versionadded:: 1.3.0 + Each row is turned into a JSON document as one element in the returned RDD. - Parameters - ---------- - use_unicode : bool, optional, default True - Whether to convert to unicode or not. + .. versionadded:: 1.3.0 - Returns - ------- - :class:`RDD` + Parameters + ---------- + use_unicode : bool, optional, default True + Whether to convert to unicode or not. - Examples - -------- - >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) - >>> df.toJSON().first() - '{"age":2,"name":"Alice"}' - """ - rdd = self._jdf.toJSON() - return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + Returns + ------- + :class:`RDD` + + Examples + -------- + >>> df = spark.createDataFrame([(2, "Alice"), (5, "Bob")], schema=["age", "name"]) + >>> df.toJSON().first() + '{"age":2,"name":"Alice"}' + """ + from pyspark.core.rdd import RDD + + rdd = self._jdf.toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def registerTempTable(self, name: str) -> None: """Registers this :class:`DataFrame` as a temporary table using the given name. @@ -3275,16 +3283,16 @@ def sort( def _jseq( self, cols: Sequence, - converter: Optional[Callable[..., Union["PrimitiveType", JavaObject]]] = None, - ) -> JavaObject: + converter: Optional[Callable[..., Union["PrimitiveType", "JavaObject"]]] = None, + ) -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or names""" return _to_seq(self.sparkSession._sc, cols, converter) - def _jmap(self, jm: Dict) -> JavaObject: + def _jmap(self, jm: Dict) -> "JavaObject": """Return a JVM Scala Map from a dict""" return _to_scala_map(self.sparkSession._sc, jm) - def _jcols(self, *cols: "ColumnOrName") -> JavaObject: + def _jcols(self, *cols: "ColumnOrName") -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or column names If `cols` has only one list in it, cols[0] will be used as the list. @@ -3293,7 +3301,7 @@ def _jcols(self, *cols: "ColumnOrName") -> JavaObject: cols = cols[0] return self._jseq(cols, _to_java_column) - def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> JavaObject: + def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject": """Return a JVM Seq of Columns from a list of Column or column names or column ordinals. If `cols` has only one list in it, cols[0] will be used as the list. 
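# Illustrative sketch, not part of this patch: the dataframe.py hunks above
# wrap JVM-only members such as DataFrame.rdd and DataFrame.toJSON in an
# `if not is_remote_only():` block at class-body level, so those attributes
# simply do not exist in a Connect-only install, and `pyspark.core.rdd` is
# imported only when the member actually runs. The names below
# (`is_remote_only` as an env-var stand-in, `Widget`, `_jvm_backend`) are
# invented for the example; only the guarding pattern comes from the diff.
import os


def is_remote_only() -> bool:
    # Hypothetical stand-in for pyspark.util.is_remote_only().
    return os.environ.get("SPARK_CONNECT_MODE_ENABLED") == "1"


class Widget:
    def describe(self) -> str:
        return "available in every distribution"

    if not is_remote_only():
        # Evaluated once, when the class body executes; the method is simply
        # never defined on a remote-only install.
        def _jvm_backend(self) -> str:
            # Deferred import: runs only when the method is called, so a
            # py4j-free install still imports this module cleanly.
            import json  # placeholder for a heavy, optional dependency

            return json.dumps({"backend": "jvm"})


w = Widget()
print(hasattr(w, "_jvm_backend"))  # False when SPARK_CONNECT_MODE_ENABLED=1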
@@ -3318,7 +3326,7 @@ def _sort_cols( self, cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], kwargs: Dict[str, Any], - ) -> JavaObject: + ) -> "JavaObject": """Return a JVM Seq of Columns that describes the sort order""" if not cols: raise PySparkValueError( @@ -4473,7 +4481,7 @@ def unpivot( def to_jcols( cols: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]] - ) -> JavaObject: + ) -> "JavaObject": if isinstance(cols, list): return self._jcols(*cols) if isinstance(cols, tuple): @@ -6392,6 +6400,8 @@ def withMetadata(self, columnName: str, metadata: Dict[str, Any]) -> "DataFrame" >>> df_meta.schema['age'].metadata {'foo': 'bar'} """ + from py4j.java_gateway import JVMView + if not isinstance(metadata, dict): raise PySparkTypeError( error_class="NOT_DICT", @@ -6542,7 +6552,7 @@ def drop(self, *cols: "ColumnOrName") -> "DataFrame": # type: ignore[misc] +---+-----+-----+ """ column_names: List[str] = [] - java_columns: List[JavaObject] = [] + java_columns: List["JavaObject"] = [] for c in cols: if isinstance(c, str): @@ -6915,7 +6925,7 @@ def pandas_api( return PandasOnSparkDataFrame(internal) -def _to_scala_map(sc: SparkContext, jm: Dict) -> JavaObject: +def _to_scala_map(sc: "SparkContext", jm: Dict) -> "JavaObject": """ Convert a dict into a JVM Map. """ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index c8a9a8975159..97e886ab6a2c 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -39,9 +39,6 @@ ValuesView, ) -from py4j.java_gateway import JVMView - -from pyspark import SparkContext from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal from pyspark.sql.dataframe import DataFrame @@ -65,6 +62,7 @@ ) if TYPE_CHECKING: + from pyspark import SparkContext from pyspark.sql._typing import ( ColumnOrName, ColumnOrName_, @@ -82,7 +80,7 @@ # since it requires making every single overridden definition. -def _get_jvm_function(name: str, sc: SparkContext) -> Callable: +def _get_jvm_function(name: str, sc: "SparkContext") -> Callable: """ Retrieves JVM function identified by name from Java gateway associated with sc. @@ -96,6 +94,8 @@ def _invoke_function(name: str, *args: Any) -> Column: Invokes JVM function identified by name with args and wraps the result with :class:`~pyspark.sql.Column`. """ + from pyspark import SparkContext + assert SparkContext._active_spark_context is not None jf = _get_jvm_function(name, SparkContext._active_spark_context) return Column(jf(*args)) @@ -5142,6 +5142,7 @@ def broadcast(df: DataFrame) -> DataFrame: | 2| 2| +-----+---+ """ + from py4j.java_gateway import JVMView sc = _get_active_spark_context() return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) @@ -17460,6 +17461,8 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: ---------- name_parts : str """ + from py4j.java_gateway import JVMView + sc = _get_active_spark_context() name_parts_seq = _to_seq(sc, name_parts) expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions @@ -17508,6 +17511,8 @@ def _create_lambda(f: Callable) -> Callable: - (Column, Column) -> Column: ... - (Column, Column, Column) -> Column: ... 
""" + from py4j.java_gateway import JVMView + parameters = _get_lambda_parameters(f) sc = _get_active_spark_context() @@ -17551,6 +17556,8 @@ def _invoke_higher_order_function( :return: a Column """ + from py4j.java_gateway import JVMView + sc = _get_active_spark_context() expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions expr = getattr(expressions, name) diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index fe74fe372a2a..15934c24b9d4 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,14 +19,13 @@ from typing import Callable, List, Optional, TYPE_CHECKING, overload, Dict, Union, cast, Tuple -from py4j.java_gateway import JavaObject - from pyspark.sql.column import Column, _to_seq from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin if TYPE_CHECKING: + from py4j.java_gateway import JavaObject from pyspark.sql._typing import LiteralType __all__ = ["GroupedData"] @@ -65,7 +64,7 @@ class GroupedData(PandasGroupedOpsMixin): Supports Spark Connect. """ - def __init__(self, jgd: JavaObject, df: DataFrame): + def __init__(self, jgd: "JavaObject", df: DataFrame): self._jgd = jgd self._df = df self.session: SparkSession = df.sparkSession diff --git a/python/pyspark/sql/observation.py b/python/pyspark/sql/observation.py index f12d1250cba2..1dae5086e3dd 100644 --- a/python/pyspark/sql/observation.py +++ b/python/pyspark/sql/observation.py @@ -15,9 +15,7 @@ # limitations under the License. # import os -from typing import Any, Dict, Optional - -from py4j.java_gateway import JavaObject, JVMView +from typing import Any, Dict, Optional, TYPE_CHECKING from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkAssertionError from pyspark.sql import column @@ -25,6 +23,10 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.utils import is_remote +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject, JVMView + + __all__ = ["Observation"] @@ -97,7 +99,7 @@ def __init__(self, name: Optional[str] = None) -> None: ) self._name = name self._jvm: Optional[JVMView] = None - self._jo: Optional[JavaObject] = None + self._jo: Optional["JavaObject"] = None def _on(self, df: DataFrame, *exprs: Column) -> DataFrame: """Attaches this observation to the given :class:`DataFrame` to observe aggregations. 
@@ -149,7 +151,7 @@ def get(self) -> Dict[str, Any]: def _test() -> None: import doctest import sys - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext from pyspark.sql import SparkSession import pyspark.sql.observation diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 891bab63b3da..4f137c7004c1 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -29,7 +29,7 @@ from pyspark.errors.exceptions.captured import unwrap_spark_exception from pyspark.loose_version import LooseVersion -from pyspark.rdd import _load_from_socket +from pyspark.util import _load_from_socket from pyspark.sql.pandas.serializers import ArrowCollectSerializer from pyspark.sql.pandas.types import _dedup_names from pyspark.sql.types import ( diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 3ca4a8743d0d..62d365a3b2a1 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -20,7 +20,7 @@ from inspect import getfullargspec, signature from typing import get_type_hints -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.pandas.typehints import infer_eval_type from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType diff --git a/python/pyspark/sql/pandas/functions.pyi b/python/pyspark/sql/pandas/functions.pyi index 5a2af7a4fed0..b053b93a278e 100644 --- a/python/pyspark/sql/pandas/functions.pyi +++ b/python/pyspark/sql/pandas/functions.pyi @@ -37,7 +37,7 @@ from pyspark.sql.pandas._typing import ( ) from pyspark import since as since # noqa: F401 -from pyspark.rdd import PythonEvalType as PythonEvalType # noqa: F401 +from pyspark.util import PythonEvalType as PythonEvalType # noqa: F401 from pyspark.sql.types import ArrayType, StructType class PandasUDFType: diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index b71d4a2a0d8b..d5b214e2f7d5 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -19,7 +19,7 @@ import warnings from pyspark.errors import PySparkValueError -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame from pyspark.sql.streaming.state import GroupStateTimeout diff --git a/python/pyspark/sql/pandas/map_ops.py b/python/pyspark/sql/pandas/map_ops.py index 25548a8b3957..8c2795a8fbe4 100644 --- a/python/pyspark/sql/pandas/map_ops.py +++ b/python/pyspark/sql/pandas/map_ops.py @@ -17,14 +17,13 @@ import sys from typing import Union, TYPE_CHECKING, Optional -from py4j.java_gateway import JavaObject - from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests -from pyspark.rdd import PythonEvalType from pyspark.resource import ResourceProfile +from pyspark.util import PythonEvalType from pyspark.sql.types import StructType if TYPE_CHECKING: + from py4j.java_gateway import JavaObject from pyspark.sql.dataframe import DataFrame from pyspark.sql.pandas._typing import PandasMapIterFunction, ArrowMapIterFunction @@ -252,7 +251,7 @@ def mapInArrow( def _build_java_profile( self, profile: Optional[ResourceProfile] = None - ) -> Optional[JavaObject]: + ) -> Optional["JavaObject"]: """Build the java ResourceProfile based on PySpark ResourceProfile""" from 
pyspark.sql import DataFrame diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py index 37ba02a94d58..c51e680329be 100644 --- a/python/pyspark/sql/pandas/typehints.py +++ b/python/pyspark/sql/pandas/typehints.py @@ -32,7 +32,7 @@ def infer_eval_type( sig: Signature, type_hints: Dict[str, Any] ) -> Union["PandasScalarUDFType", "PandasScalarIterUDFType", "PandasGroupedAggUDFType"]: """ - Infers the evaluation type in :class:`pyspark.rdd.PythonEvalType` from + Infers the evaluation type in :class:`pyspark.util.PythonEvalType` from :class:`inspect.Signature` instance and type hints. """ from pyspark.sql.pandas.functions import PandasUDFType diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py index 5a99aed55f74..63871e437571 100644 --- a/python/pyspark/sql/protobuf/functions.py +++ b/python/pyspark/sql/protobuf/functions.py @@ -22,8 +22,6 @@ from typing import Dict, Optional, TYPE_CHECKING, cast -from py4j.java_gateway import JVMView - from pyspark.sql.column import Column, _to_java_column from pyspark.sql.utils import get_active_spark_context, try_remote_protobuf_functions from pyspark.util import _print_missing_jar @@ -140,6 +138,7 @@ def from_protobuf( |{1668035962, 2020}| +------------------+ """ + from py4j.java_gateway import JVMView sc = get_active_spark_context() try: @@ -260,6 +259,7 @@ def to_protobuf( |[08 FA EA B0 9B 06 10 E4 0F]| +----------------------------+ """ + from py4j.java_gateway import JVMView sc = get_active_spark_context() try: diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 9eb5d99dfa4c..26fe8c5e6fa2 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -17,9 +17,7 @@ import sys from typing import cast, overload, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING, Union -from py4j.java_gateway import JavaClass, JavaObject - -from pyspark import RDD +from pyspark.util import is_remote_only from pyspark.sql.column import _to_seq, _to_java_column, Column from pyspark.sql.types import StructType from pyspark.sql import utils @@ -27,6 +25,8 @@ from pyspark.errors import PySparkTypeError, PySparkValueError if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.rdd import RDD from pyspark.sql._typing import OptionalPrimitiveType, ColumnOrName from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame @@ -70,7 +70,7 @@ def __init__(self, spark: "SparkSession"): self._jreader = spark._jsparkSession.read() self._spark = spark - def _df(self, jdf: JavaObject) -> "DataFrame": + def _df(self, jdf: "JavaObject") -> "DataFrame": from pyspark.sql.dataframe import DataFrame return DataFrame(jdf, self._spark) @@ -320,7 +320,7 @@ def load( def json( self, - path: Union[str, List[str], RDD[str]], + path: Union[str, List[str], "RDD[str]"], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, @@ -465,7 +465,11 @@ def json( if type(path) == list: assert self._spark._sc._jvm is not None return self._df(self._jreader.json(self._spark._sc._jvm.PythonUtils.toSeq(path))) - elif isinstance(path, RDD): + + if not is_remote_only(): + from pyspark.core.rdd import RDD # noqa: F401 + + if not is_remote_only() and isinstance(path, RDD): def func(iterator: Iterable) -> Iterable: for x in iterator: @@ -829,7 +833,11 @@ def csv( if type(path) == list: assert self._spark._sc._jvm is not None 
return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) - elif isinstance(path, RDD): + + if not is_remote_only(): + from pyspark.core.rdd import RDD # noqa: F401 + + if not is_remote_only() and isinstance(path, RDD): def func(iterator): for x in iterator: @@ -861,7 +869,7 @@ def func(iterator): def xml( self, - path: Union[str, List[str], RDD[str]], + path: Union[str, List[str], "RDD[str]"], rowTag: Optional[str] = None, schema: Optional[Union[StructType, str]] = None, excludeAttribute: Optional[Union[bool, str]] = None, @@ -952,7 +960,11 @@ def xml( if type(path) == list: assert self._spark._sc._jvm is not None return self._df(self._jreader.xml(self._spark._sc._jvm.PythonUtils.toSeq(path))) - elif isinstance(path, RDD): + + if not is_remote_only(): + from pyspark.core.rdd import RDD # noqa: F401 + + if not is_remote_only() and isinstance(path, RDD): def func(iterator: Iterable) -> Iterable: for x in iterator: @@ -1132,6 +1144,8 @@ def jdbc( ------- :class:`DataFrame` """ + from py4j.java_gateway import JavaClass + if properties is None: properties = dict() assert self._spark._sc._gateway is not None @@ -1177,7 +1191,7 @@ def __init__(self, df: "DataFrame"): self._spark = df.sparkSession self._jwrite = df._jdf.write() - def _sq(self, jsq: JavaObject) -> "StreamingQuery": + def _sq(self, jsq: "JavaObject") -> "StreamingQuery": from pyspark.sql.streaming import StreamingQuery return StreamingQuery(jsq) @@ -2269,6 +2283,8 @@ def jdbc( Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash your external database systems. """ + from py4j.java_gateway import JavaClass + if properties is None: properties = dict() @@ -2439,7 +2455,7 @@ def _test() -> None: import doctest import os import py4j - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext from pyspark.sql import SparkSession import pyspark.sql.readwriter diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 4a8a653fd466..f065a106bbf2 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -38,10 +38,7 @@ TYPE_CHECKING, ) -from py4j.java_gateway import JavaObject - -from pyspark import SparkConf, SparkContext -from pyspark.rdd import RDD +from pyspark.util import is_remote_only from pyspark.sql.column import _to_java_column from pyspark.sql.conf import RuntimeConfig from pyspark.sql.dataframe import DataFrame @@ -69,6 +66,10 @@ from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.conf import SparkConf + from pyspark.core.context import SparkContext + from pyspark.core.rdd import RDD from pyspark.sql._typing import AtomicValue, RowLike, OptionalPrimitiveType from pyspark.sql.catalog import Catalog from pyspark.sql.pandas._typing import ArrayLike, DataFrameLike as PandasDataFrameLike @@ -129,7 +130,10 @@ def toDF(self, schema=None, sampleRatio=None): """ return sparkSession.createDataFrame(self, schema, sampleRatio) - RDD.toDF = toDF # type: ignore[method-assign] + if not is_remote_only(): + from pyspark import RDD + + RDD.toDF = toDF # type: ignore[method-assign] # TODO(SPARK-38912): This method can be dropped once support for Python 3.8 is dropped @@ -216,7 +220,7 @@ def __init__(self) -> None: self._options: Dict[str, Any] = {} @overload - def config(self, *, conf: SparkConf) -> "SparkSession.Builder": + def config(self, *, conf: "SparkConf") -> "SparkSession.Builder": ... 
@overload @@ -231,7 +235,7 @@ def config( self, key: Optional[str] = None, value: Optional[Any] = None, - conf: Optional[SparkConf] = None, + conf: Optional["SparkConf"] = None, *, map: Optional[Dict[str, "OptionalPrimitiveType"]] = None, ) -> "SparkSession.Builder": @@ -268,7 +272,7 @@ def config( -------- For an existing :class:`SparkConf`, use `conf` parameter. - >>> from pyspark.conf import SparkConf + >>> from pyspark.core.conf import SparkConf >>> conf = SparkConf().setAppName("example").setMaster("local") >>> SparkSession.builder.config(conf=conf) "SparkSession": >>> s1.conf.get("k2") == s2.conf.get("k2") == "v2" True """ - from pyspark.context import SparkContext - from pyspark.conf import SparkConf - opts = dict(self._options) + if is_remote_only(): + from pyspark.sql.connect.session import SparkSession as RemoteSparkSession + + url = opts.get("spark.remote", os.environ.get("SPARK_REMOTE")) + + if url is None: + raise PySparkRuntimeError( + error_class="CONNECT_URL_NOT_SET", + message_parameters={}, + ) + + os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1" + opts["spark.remote"] = url + return RemoteSparkSession.builder.config(map=opts).getOrCreate() # type: ignore + + from pyspark.core.context import SparkContext + from pyspark.core.conf import SparkConf + with self._lock: if ( "SPARK_CONNECT_MODE_ENABLED" in os.environ @@ -596,8 +615,8 @@ def create(self) -> "SparkSession": def __init__( self, - sparkContext: SparkContext, - jsparkSession: Optional[JavaObject] = None, + sparkContext: "SparkContext", + jsparkSession: Optional["JavaObject"] = None, options: Dict[str, Any] = {}, ): self._sc = sparkContext @@ -655,25 +674,27 @@ def _jconf(self) -> "JavaObject": """Accessor for the JVM SQL-specific configurations""" return self._jsparkSession.sessionState().conf() - def newSession(self) -> "SparkSession": - """ - Returns a new :class:`SparkSession` as new session, that has separate SQLConf, - registered temporary views and UDFs, but shared :class:`SparkContext` and - table cache. + if not is_remote_only(): - .. versionadded:: 2.0.0 + def newSession(self) -> "SparkSession": + """ + Returns a new :class:`SparkSession` as new session, that has separate SQLConf, + registered temporary views and UDFs, but shared :class:`SparkContext` and + table cache. - Returns - ------- - :class:`SparkSession` - Spark session if an active session exists for the current thread + .. versionadded:: 2.0.0 - Examples - -------- - >>> spark.newSession() - <...SparkSession object ...> - """ - return self.__class__(self._sc, self._jsparkSession.newSession()) + Returns + ------- + :class:`SparkSession` + Spark session if an active session exists for the current thread + + Examples + -------- + >>> spark.newSession() + <...SparkSession object ...> + """ + return self.__class__(self._sc, self._jsparkSession.newSession()) @classmethod @try_remote_session_classmethod @@ -739,29 +760,31 @@ def active(cls) -> "SparkSession": ) return session - @property - def sparkContext(self) -> SparkContext: - """ - Returns the underlying :class:`SparkContext`. + if not is_remote_only(): - .. versionadded:: 2.0.0 + @property + def sparkContext(self) -> "SparkContext": + """ + Returns the underlying :class:`SparkContext`. - Returns - ------- - :class:`SparkContext` + .. 
versionadded:: 2.0.0 - Examples - -------- - >>> spark.sparkContext - + Returns + ------- + :class:`SparkContext` + + Examples + -------- + >>> spark.sparkContext + - Create an RDD from the Spark context + Create an RDD from the Spark context - >>> rdd = spark.sparkContext.parallelize([1, 2, 3]) - >>> rdd.collect() - [1, 2, 3] - """ - return self._sc + >>> rdd = spark.sparkContext.parallelize([1, 2, 3]) + >>> rdd.collect() + [1, 2, 3] + """ + return self._sc @property def version(self) -> str: @@ -1040,7 +1063,7 @@ def _inferSchemaFromList( def _inferSchema( self, - rdd: RDD[Any], + rdd: "RDD[Any]", samplingRatio: Optional[float] = None, names: Optional[List[str]] = None, ) -> StructType: @@ -1111,10 +1134,10 @@ def _inferSchema( def _createFromRDD( self, - rdd: RDD[Any], + rdd: "RDD[Any]", schema: Optional[Union[DataType, List[str]]], samplingRatio: Optional[float], - ) -> Tuple[RDD[Tuple], StructType]: + ) -> Tuple["RDD[Tuple]", StructType]: """ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. """ @@ -1146,7 +1169,7 @@ def _createFromRDD( def _createFromLocal( self, data: Iterable[Any], schema: Optional[Union[DataType, List[str]]] - ) -> Tuple[RDD[Tuple], StructType]: + ) -> Tuple["RDD[Tuple]", StructType]: """ Create an RDD for DataFrame from a list or pandas.DataFrame, returns the RDD and schema. @@ -1189,8 +1212,8 @@ def _create_shell_session() -> "SparkSession": that script, which would expose those to users. """ import py4j - from pyspark.conf import SparkConf - from pyspark.context import SparkContext + from pyspark.core.conf import SparkConf + from pyspark.core.context import SparkContext try: # Try to access HiveConf, it will raise exception if Hive is not added @@ -1300,7 +1323,7 @@ def createDataFrame( def createDataFrame( # type: ignore[misc] self, - data: Union[RDD[Any], Iterable[Any], "PandasDataFrameLike", "ArrayLike"], + data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike"], schema: Optional[Union[AtomicType, StructType, str]] = None, samplingRatio: Optional[float] = None, verifySchema: bool = True, @@ -1515,7 +1538,7 @@ def createDataFrame( # type: ignore[misc] def _create_dataframe( self, - data: Union[RDD[Any], Iterable[Any]], + data: Union["RDD[Any]", Iterable[Any]], schema: Optional[Union[DataType, List[str]]], samplingRatio: Optional[float], verifySchema: bool, @@ -1548,10 +1571,14 @@ def prepare(obj): def prepare(obj: Any) -> Any: return obj - if isinstance(data, RDD): + if not is_remote_only(): + from pyspark.core.rdd import RDD + if not is_remote_only() and isinstance(data, RDD): rdd, struct = self._createFromRDD(data.map(prepare), schema, samplingRatio) else: - rdd, struct = self._createFromLocal(map(prepare, data), schema) + rdd, struct = self._createFromLocal( + map(prepare, data), schema # type: ignore[arg-type] + ) assert self._jvm is not None jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), struct.json()) diff --git a/python/pyspark/sql/sql_formatter.py b/python/pyspark/sql/sql_formatter.py index a27f7205a2d7..6d37821e5374 100644 --- a/python/pyspark/sql/sql_formatter.py +++ b/python/pyspark/sql/sql_formatter.py @@ -20,8 +20,6 @@ from typing import Any, Optional, List, Tuple, Sequence, Mapping import uuid -from py4j.java_gateway import is_instance_of - if typing.TYPE_CHECKING: from pyspark.sql import SparkSession, DataFrame from pyspark.sql.functions import lit @@ -47,6 +45,8 @@ def _convert_value(self, val: Any, field_name: str) -> 
Optional[str]: """ Converts the given value into a SQL string. """ + from py4j.java_gateway import is_instance_of + from pyspark import SparkContext from pyspark.sql import Column, DataFrame diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index c7c962578e2a..c1c9dce04731 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -16,16 +16,17 @@ # import uuid import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TYPE_CHECKING from abc import ABC, abstractmethod -from py4j.java_gateway import JavaObject - from pyspark.sql import Row from pyspark import cloudpickle __all__ = ["StreamingQueryListener"] +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + class StreamingQueryListener(ABC): """ @@ -124,13 +125,13 @@ def onQueryTerminated(self, event: "QueryTerminatedEvent") -> None: pass @property - def _jlistener(self) -> JavaObject: + def _jlistener(self) -> "JavaObject": from pyspark import SparkContext if hasattr(self, "_jlistenerobj"): return self._jlistenerobj - self._jlistenerobj: JavaObject = ( + self._jlistenerobj: "JavaObject" = ( SparkContext._jvm.PythonStreamingQueryListenerWrapper( # type: ignore[union-attr] JStreamingQueryListener(self) ) @@ -146,16 +147,16 @@ class JStreamingQueryListener: def __init__(self, pylistener: StreamingQueryListener) -> None: self.pylistener = pylistener - def onQueryStarted(self, jevent: JavaObject) -> None: + def onQueryStarted(self, jevent: "JavaObject") -> None: self.pylistener.onQueryStarted(QueryStartedEvent.fromJObject(jevent)) - def onQueryProgress(self, jevent: JavaObject) -> None: + def onQueryProgress(self, jevent: "JavaObject") -> None: self.pylistener.onQueryProgress(QueryProgressEvent.fromJObject(jevent)) - def onQueryIdle(self, jevent: JavaObject) -> None: + def onQueryIdle(self, jevent: "JavaObject") -> None: self.pylistener.onQueryIdle(QueryIdleEvent.fromJObject(jevent)) - def onQueryTerminated(self, jevent: JavaObject) -> None: + def onQueryTerminated(self, jevent: "JavaObject") -> None: self.pylistener.onQueryTerminated(QueryTerminatedEvent.fromJObject(jevent)) class Java: @@ -182,7 +183,7 @@ def __init__( self._timestamp: str = timestamp @classmethod - def fromJObject(cls, jevent: JavaObject) -> "QueryStartedEvent": + def fromJObject(cls, jevent: "JavaObject") -> "QueryStartedEvent": return cls( id=uuid.UUID(jevent.id().toString()), runId=uuid.UUID(jevent.runId().toString()), @@ -245,7 +246,7 @@ def __init__(self, progress: "StreamingQueryProgress") -> None: self._progress: StreamingQueryProgress = progress @classmethod - def fromJObject(cls, jevent: JavaObject) -> "QueryProgressEvent": + def fromJObject(cls, jevent: "JavaObject") -> "QueryProgressEvent": return cls(progress=StreamingQueryProgress.fromJObject(jevent.progress())) @classmethod @@ -277,7 +278,7 @@ def __init__(self, id: uuid.UUID, runId: uuid.UUID, timestamp: str) -> None: self._timestamp: str = timestamp @classmethod - def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent": + def fromJObject(cls, jevent: "JavaObject") -> "QueryIdleEvent": return cls( id=uuid.UUID(jevent.id().toString()), runId=uuid.UUID(jevent.runId().toString()), @@ -336,7 +337,7 @@ def __init__( self._errorClassOnException: Optional[str] = errorClassOnException @classmethod - def fromJObject(cls, jevent: JavaObject) -> "QueryTerminatedEvent": + def fromJObject(cls, jevent: "JavaObject") -> "QueryTerminatedEvent": jexception = 
jevent.exception() jerrorclass = jevent.errorClassOnException() return cls( @@ -419,10 +420,10 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, observedMetrics: Dict[str, Row], - jprogress: Optional[JavaObject] = None, + jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ): - self._jprogress: Optional[JavaObject] = jprogress + self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict self._id: uuid.UUID = id self._runId: uuid.UUID = runId @@ -441,7 +442,7 @@ def __init__( self._observedMetrics: Dict[str, Row] = observedMetrics @classmethod - def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": + def fromJObject(cls, jprogress: "JavaObject") -> "StreamingQueryProgress": from pyspark import SparkContext return cls( @@ -664,10 +665,10 @@ def __init__( numShufflePartitions: int, numStateStoreInstances: int, customMetrics: Dict[str, int], - jprogress: Optional[JavaObject] = None, + jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ): - self._jprogress: Optional[JavaObject] = jprogress + self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict self._operatorName: str = operatorName self._numRowsTotal: int = numRowsTotal @@ -683,7 +684,7 @@ def __init__( self._customMetrics: Dict[str, int] = customMetrics @classmethod - def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress": + def fromJObject(cls, jprogress: "JavaObject") -> "StateOperatorProgress": return cls( jprogress=jprogress, operatorName=jprogress.operatorName(), @@ -811,10 +812,10 @@ def __init__( inputRowsPerSecond: float, processedRowsPerSecond: float, metrics: Dict[str, str], - jprogress: Optional[JavaObject] = None, + jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ) -> None: - self._jprogress: Optional[JavaObject] = jprogress + self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict self._description: str = description self._startOffset: str = startOffset @@ -826,7 +827,7 @@ def __init__( self._metrics: Dict[str, str] = metrics @classmethod - def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": + def fromJObject(cls, jprogress: "JavaObject") -> "SourceProgress": return cls( jprogress=jprogress, description=jprogress.description(), @@ -946,17 +947,17 @@ def __init__( description: str, numOutputRows: int, metrics: Dict[str, str], - jprogress: Optional[JavaObject] = None, + jprogress: Optional["JavaObject"] = None, jdict: Optional[Dict[str, Any]] = None, ) -> None: - self._jprogress: Optional[JavaObject] = jprogress + self._jprogress: Optional["JavaObject"] = jprogress self._jdict: Optional[Dict[str, Any]] = jdict self._description: str = description self._numOutputRows: int = numOutputRows self._metrics: Dict[str, str] = metrics @classmethod - def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": + def fromJObject(cls, jprogress: "JavaObject") -> "SinkProgress": return cls( jprogress=jprogress, description=jprogress.description(), diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index 512191866a16..76f9048e3edb 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -22,7 +22,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors 
import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_int, write_int, @@ -34,7 +33,7 @@ _parse_datatype_json_string, StructType, ) -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, read_command, diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index db104e30755a..bcab8a104f1d 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -16,9 +16,7 @@ # import json -from typing import Any, Dict, List, Optional - -from py4j.java_gateway import JavaObject, java_import +from typing import Any, Dict, List, Optional, TYPE_CHECKING from pyspark.errors import StreamingQueryException, PySparkValueError from pyspark.errors.exceptions.captured import ( @@ -26,6 +24,9 @@ ) from pyspark.sql.streaming.listener import StreamingQueryListener +if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + __all__ = ["StreamingQuery", "StreamingQueryManager"] @@ -44,7 +45,7 @@ class StreamingQuery: This API is evolving. """ - def __init__(self, jsq: JavaObject) -> None: + def __init__(self, jsq: "JavaObject") -> None: self._jsq = jsq @property @@ -450,7 +451,7 @@ class StreamingQueryManager: This API is evolving. """ - def __init__(self, jsqm: JavaObject) -> None: + def __init__(self, jsqm: "JavaObject") -> None: self._jsqm = jsqm @property @@ -662,6 +663,7 @@ def addListener(self, listener: StreamingQueryListener) -> None: >>> spark.streams.removeListener(test_listener) """ + from py4j.java_gateway import java_import from pyspark import SparkContext from pyspark.java_gateway import ensure_callback_server_started diff --git a/python/pyspark/sql/streaming/readwriter.py b/python/pyspark/sql/streaming/readwriter.py index 41a83355ab6c..58901f34cfc9 100644 --- a/python/pyspark/sql/streaming/readwriter.py +++ b/python/pyspark/sql/streaming/readwriter.py @@ -19,8 +19,6 @@ from collections.abc import Iterator from typing import cast, overload, Any, Callable, List, Optional, TYPE_CHECKING, Union -from py4j.java_gateway import java_import, JavaObject - from pyspark.sql.column import _to_seq from pyspark.sql.readwriter import OptionUtils, to_str from pyspark.sql.streaming.query import StreamingQuery @@ -34,6 +32,7 @@ ) if TYPE_CHECKING: + from py4j.java_gateway import JavaObject from pyspark.sql.session import SparkSession from pyspark.sql._typing import SupportsProcess, OptionalPrimitiveType from pyspark.sql.dataframe import DataFrame @@ -77,7 +76,7 @@ def __init__(self, spark: "SparkSession") -> None: self._jreader = spark._jsparkSession.readStream() self._spark = spark - def _df(self, jdf: JavaObject) -> "DataFrame": + def _df(self, jdf: "JavaObject") -> "DataFrame": from pyspark.sql.dataframe import DataFrame return DataFrame(jdf, self._spark) @@ -908,7 +907,7 @@ def __init__(self, df: "DataFrame") -> None: self._spark = df.sparkSession self._jwrite = df._jdf.writeStream() - def _sq(self, jsq: JavaObject) -> StreamingQuery: + def _sq(self, jsq: "JavaObject") -> StreamingQuery: return StreamingQuery(jsq) def outputMode(self, outputMode: str) -> "DataStreamWriter": @@ -1489,7 +1488,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt >>> q.stop() """ - from pyspark.rdd import _wrap_function + from pyspark.core.rdd import _wrap_function from pyspark.serializers 
import CPickleSerializer, AutoBatchedSerializer func = self._construct_foreach_function(f) @@ -1541,7 +1540,7 @@ def foreachBatch(self, func: Callable[["DataFrame", int], None]) -> "DataStreamW >>> q.stop() >>> # if in Spark Connect, my_value = -1, else my_value = 100 """ - + from py4j.java_gateway import java_import from pyspark.java_gateway import ensure_callback_server_started gw = self._spark._sc._gateway diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py b/python/pyspark/sql/tests/pandas/test_pandas_udf.py index 21213c3e7281..82a93524fcf9 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py @@ -22,7 +22,7 @@ from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, lit from pyspark.sql.types import DoubleType, StructType, StructField, LongType, DayTimeIntervalType from pyspark.errors import ParseException, PythonException, PySparkTypeError -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.testing.sqlutils import ( ReusedSQLTestCase, have_pandas, diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py index 8cc5c2b6aa43..a7cf45e3bcbe 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py @@ -18,7 +18,7 @@ import unittest from typing import cast -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql import Row from pyspark.sql.functions import ( array, diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py index 44d5c9d1ed94..ec413d048d8e 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py @@ -25,7 +25,7 @@ from typing import cast from pyspark import TaskContext -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql import Column from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf, PandasUDFType from pyspark.sql.types import ( diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 114fdf602223..23f302ec3c8d 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -29,7 +29,7 @@ pyarrow_requirement_message, ReusedSQLTestCase, ) -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType @unittest.skipIf( diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index 5f102d770c6a..ce2c6bf3a1df 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -193,8 +193,8 @@ def test_create_new_session_with_statement(self): session.range(5).collect() def test_active_session_with_None_and_not_None_context(self): - from pyspark.context import SparkContext - from pyspark.conf import SparkConf + from pyspark.core.context import SparkContext + from pyspark.core.conf import SparkConf sc = None session = None diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index d69d710de570..0d2582b51fe1 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -30,8 +30,8 @@ AnalysisException, PySparkPicklingError, 
) -from pyspark.files import SparkFiles -from pyspark.rdd import PythonEvalType +from pyspark.core.files import SparkFiles +from pyspark.util import PythonEvalType from pyspark.sql.functions import ( array, create_map, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 9c4647d3b38a..49342bd21323 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -45,9 +45,7 @@ TYPE_CHECKING, ) -from py4j.protocol import register_input_converter -from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject, JVMView - +from pyspark.util import is_remote_only from pyspark.serializers import CloudPickleSerializer from pyspark.sql.utils import has_numpy, get_active_spark_context from pyspark.errors import ( @@ -63,6 +61,10 @@ if has_numpy: import numpy as np +if TYPE_CHECKING: + import numpy as np + from py4j.java_gateway import GatewayClient, JavaGateway, JavaClass + T = TypeVar("T") U = TypeVar("U") @@ -96,10 +98,6 @@ ] -if TYPE_CHECKING: - import numpy as np - - class DataType: """Base class for data types.""" @@ -1546,6 +1544,8 @@ def _parse_datatype_string(s: str) -> DataType: ... ParseException:... """ + from py4j.java_gateway import JVMView + sc = get_active_spark_context() def from_ddl_schema(type_str: str) -> DataType: @@ -2758,7 +2758,9 @@ class DateConverter: def can_convert(self, obj: Any) -> bool: return isinstance(obj, datetime.date) - def convert(self, obj: datetime.date, gateway_client: GatewayClient) -> JavaObject: + def convert(self, obj: datetime.date, gateway_client: "GatewayClient") -> "JavaGateway": + from py4j.java_gateway import JavaClass + Date = JavaClass("java.sql.Date", gateway_client) return Date.valueOf(obj.strftime("%Y-%m-%d")) @@ -2767,7 +2769,9 @@ class DatetimeConverter: def can_convert(self, obj: Any) -> bool: return isinstance(obj, datetime.datetime) - def convert(self, obj: datetime.datetime, gateway_client: GatewayClient) -> JavaObject: + def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway": + from py4j.java_gateway import JavaClass + Timestamp = JavaClass("java.sql.Timestamp", gateway_client) seconds = ( calendar.timegm(obj.utctimetuple()) if obj.tzinfo else time.mktime(obj.timetuple()) @@ -2787,7 +2791,9 @@ def can_convert(self, obj: Any) -> bool: and is_timestamp_ntz_preferred() ) - def convert(self, obj: datetime.datetime, gateway_client: GatewayClient) -> JavaObject: + def convert(self, obj: datetime.datetime, gateway_client: "GatewayClient") -> "JavaGateway": + from py4j.java_gateway import JavaClass + seconds = calendar.timegm(obj.utctimetuple()) DateTimeUtils = JavaClass( "org.apache.spark.sql.catalyst.util.DateTimeUtils", @@ -2800,7 +2806,9 @@ class DayTimeIntervalTypeConverter: def can_convert(self, obj: Any) -> bool: return isinstance(obj, datetime.timedelta) - def convert(self, obj: datetime.timedelta, gateway_client: GatewayClient) -> JavaObject: + def convert(self, obj: datetime.timedelta, gateway_client: "GatewayClient") -> "JavaGateway": + from py4j.java_gateway import JavaClass + IntervalUtils = JavaClass( "org.apache.spark.sql.catalyst.util.IntervalUtils", gateway_client, @@ -2814,14 +2822,14 @@ class NumpyScalarConverter: def can_convert(self, obj: Any) -> bool: return has_numpy and isinstance(obj, np.generic) - def convert(self, obj: "np.generic", gateway_client: GatewayClient) -> Any: + def convert(self, obj: "np.generic", gateway_client: "GatewayClient") -> Any: return obj.item() class NumpyArrayConverter: def _from_numpy_type_to_java_type( - 
self, nt: "np.dtype", gateway: JavaGateway - ) -> Optional[JavaClass]: + self, nt: "np.dtype", gateway: "JavaGateway" + ) -> Optional["JavaClass"]: """Convert NumPy type to Py4J Java type.""" if nt in [np.dtype("int8"), np.dtype("int16")]: # Mapping int8 to gateway.jvm.byte causes @@ -2843,7 +2851,7 @@ def _from_numpy_type_to_java_type( def can_convert(self, obj: Any) -> bool: return has_numpy and isinstance(obj, np.ndarray) and obj.ndim == 1 - def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObject: + def convert(self, obj: "np.ndarray", gateway_client: "GatewayClient") -> "JavaGateway": from pyspark import SparkContext gateway = SparkContext._gateway @@ -2865,15 +2873,18 @@ def convert(self, obj: "np.ndarray", gateway_client: GatewayClient) -> JavaObjec return jarr -# datetime is a subclass of date, we should register DatetimeConverter first -register_input_converter(DatetimeNTZConverter()) -register_input_converter(DatetimeConverter()) -register_input_converter(DateConverter()) -register_input_converter(DayTimeIntervalTypeConverter()) -register_input_converter(NumpyScalarConverter()) -# NumPy array satisfies py4j.java_collections.ListConverter, -# so prepend NumpyArrayConverter -register_input_converter(NumpyArrayConverter(), prepend=True) +if not is_remote_only(): + from py4j.protocol import register_input_converter + + # datetime is a subclass of date, we should register DatetimeConverter first + register_input_converter(DatetimeNTZConverter()) + register_input_converter(DatetimeConverter()) + register_input_converter(DateConverter()) + register_input_converter(DayTimeIntervalTypeConverter()) + register_input_converter(NumpyScalarConverter()) + # NumPy array satisfies py4j.java_collections.ListConverter, + # so prepend NumpyArrayConverter + register_input_converter(NumpyArrayConverter(), prepend=True) def _test() -> None: diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 0324bc678667..0d0fc9042e62 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,10 +25,8 @@ import warnings from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union -from py4j.java_gateway import JavaObject -from pyspark import SparkContext -from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.column import Column, _to_java_expr, _to_seq from pyspark.sql.types import ( DataType, @@ -42,6 +40,8 @@ from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError if TYPE_CHECKING: + from py4j.java_gateway import JavaObject + from pyspark.core.context import SparkContext from pyspark.sql._typing import DataTypeOrString, ColumnOrName, UserDefinedFunctionLike from pyspark.sql.session import SparkSession @@ -49,8 +49,10 @@ def _wrap_function( - sc: SparkContext, func: Callable[..., Any], returnType: Optional[DataType] = None -) -> JavaObject: + sc: "SparkContext", func: Callable[..., Any], returnType: Optional[DataType] = None +) -> "JavaObject": + from pyspark.core.rdd import _prepare_for_python_RDD + command: Any if returnType is None: command = func @@ -369,7 +371,7 @@ def returnType(self) -> DataType: return self._returnType_placeholder @property - def _judf(self) -> JavaObject: + def _judf(self) -> "JavaObject": # It is possible that concurrent access, to newly created UDF, # will initialize multiple UserDefinedPythonFunctions. 
# This is unlikely, doesn't affect correctness, @@ -378,7 +380,7 @@ def _judf(self) -> JavaObject: self._judf_placeholder = self._create_judf(self.func) return self._judf_placeholder - def _create_judf(self, func: Callable[..., Any]) -> JavaObject: + def _create_judf(self, func: Callable[..., Any]) -> "JavaObject": from pyspark.sql import SparkSession spark = SparkSession._getActiveSessionOrCreate() diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 5bf95277baac..801ecc605e50 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -24,16 +24,15 @@ import warnings from typing import Any, Type, TYPE_CHECKING, Optional, Sequence, Union -from py4j.java_gateway import JavaObject - from pyspark.errors import PySparkAttributeError, PySparkPicklingError, PySparkTypeError -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType from pyspark.sql.column import _to_java_column, _to_java_expr, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version from pyspark.sql.types import DataType, StructType, _parse_datatype_string from pyspark.sql.udf import _wrap_function if TYPE_CHECKING: + from py4j.java_gateway import JavaObject from pyspark.sql._typing import ColumnOrName from pyspark.sql.dataframe import DataFrame from pyspark.sql.session import SparkSession @@ -328,12 +327,12 @@ def returnType(self) -> Optional[StructType]: return self._returnType_placeholder @property - def _judtf(self) -> JavaObject: + def _judtf(self) -> "JavaObject": if self._judtf_placeholder is None: self._judtf_placeholder = self._create_judtf(self.func) return self._judtf_placeholder - def _create_judtf(self, func: Type) -> JavaObject: + def _create_judtf(self, func: Type) -> "JavaObject": from pyspark.sql import SparkSession spark = SparkSession._getActiveSessionOrCreate() diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8d05fa54d270..09ad959e2b8e 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -19,15 +19,6 @@ import os from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type -from py4j.java_collections import JavaArray -from py4j.java_gateway import ( - JavaClass, - JavaGateway, - JavaObject, -) - -from pyspark import SparkContext - # For backward compatibility. 
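With PythonEvalType re-homed in pyspark.util, udf.py and udtf.py no longer import pyspark.rdd just for the eval-type constants. A small illustrative dispatcher on the new import path (describe_udf is hypothetical, not part of the diff):

    from pyspark.util import PythonEvalType  # previously pyspark.rdd.PythonEvalType


    def describe_udf(eval_type: int) -> str:
        """Map a few of the shared eval-type constants to readable names."""
        if eval_type == PythonEvalType.SQL_BATCHED_UDF:
            return "row-at-a-time Python UDF"
        if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF:
            return "Arrow-optimized Python UDF"
        if eval_type == PythonEvalType.SQL_TABLE_UDF:
            return "Python user-defined table function"
        return "other eval type"


    print(describe_udf(PythonEvalType.SQL_BATCHED_UDF))

The values themselves are unchanged; they still mirror org.apache.spark.api.python.PythonEvalType on the JVM side.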
from pyspark.errors import ( # noqa: F401 AnalysisException, @@ -41,10 +32,18 @@ PySparkNotImplementedError, PySparkRuntimeError, ) +from pyspark.util import is_remote_only from pyspark.errors.exceptions.captured import CapturedException # noqa: F401 from pyspark.find_spark_home import _find_spark_home if TYPE_CHECKING: + from py4j.java_collections import JavaArray + from py4j.java_gateway import ( + JavaClass, + JavaGateway, + JavaObject, + ) + from pyspark import SparkContext from pyspark.sql.session import SparkSession from pyspark.sql.dataframe import DataFrame from pyspark.sql.column import Column @@ -63,7 +62,7 @@ FuncT = TypeVar("FuncT", bound=Callable[..., Any]) -def toJArray(gateway: JavaGateway, jtype: JavaClass, arr: Sequence[Any]) -> JavaArray: +def toJArray(gateway: "JavaGateway", jtype: "JavaClass", arr: Sequence[Any]) -> "JavaArray": """ Convert python list to java type array @@ -76,7 +75,7 @@ def toJArray(gateway: JavaGateway, jtype: JavaClass, arr: Sequence[Any]) -> Java arr : python type list """ - jarray: JavaArray = gateway.new_array(jtype, len(arr)) + jarray: "JavaArray" = gateway.new_array(jtype, len(arr)) for i in range(0, len(arr)): jarray[i] = arr[i] return jarray @@ -108,7 +107,7 @@ def __init__(self, session: "SparkSession", func: Callable[["DataFrame", int], N self.func = func self.session = session - def call(self, jdf: JavaObject, batch_id: int) -> None: + def call(self, jdf: "JavaObject", batch_id: int) -> None: from pyspark.sql.dataframe import DataFrame from pyspark.sql.session import SparkSession @@ -151,6 +150,8 @@ def is_timestamp_ntz_preferred() -> bool: else: return session.conf.get("spark.sql.timestampType", None) == "TIMESTAMP_NTZ" else: + from pyspark import SparkContext + jvm = SparkContext._jvm return jvm is not None and jvm.PythonSQLUtils.isTimestampNTZPreferred() @@ -178,7 +179,7 @@ def is_remote() -> bool: >>> is_remote() False """ - return "SPARK_CONNECT_MODE_ENABLED" in os.environ + return ("SPARK_CONNECT_MODE_ENABLED" in os.environ) or is_remote_only() def try_remote_functions(f: FuncT) -> FuncT: @@ -271,9 +272,11 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: return cast(FuncT, wrapped) -def get_active_spark_context() -> SparkContext: +def get_active_spark_context() -> "SparkContext": """Raise RuntimeError if SparkContext is not initialized, otherwise, returns the active SparkContext.""" + from pyspark import SparkContext + sc = SparkContext._active_spark_context if sc is None or sc._jvm is None: raise PySparkRuntimeError( diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py index ca05cb0cc7fd..42d50dc1b3bd 100644 --- a/python/pyspark/sql/window.py +++ b/python/pyspark/sql/window.py @@ -17,8 +17,6 @@ import sys from typing import cast, Iterable, List, Tuple, TYPE_CHECKING, Union -from py4j.java_gateway import JavaObject, JVMView - from pyspark.sql.column import _to_seq, _to_java_column from pyspark.sql.utils import ( try_remote_window, @@ -27,12 +25,13 @@ ) if TYPE_CHECKING: + from py4j.java_gateway import JavaObject from pyspark.sql._typing import ColumnOrName, ColumnOrName_ __all__ = ["Window", "WindowSpec"] -def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> JavaObject: +def _to_java_cols(cols: Tuple[Union["ColumnOrName", List["ColumnOrName_"]], ...]) -> "JavaObject": if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] sc = get_active_spark_context() @@ -125,6 +124,8 @@ def partitionBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> 
"WindowS | 3| b| 3| +---+--------+----------+ """ + from py4j.java_gateway import JVMView + sc = get_active_spark_context() jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.partitionBy( _to_java_cols(cols) @@ -182,6 +183,8 @@ def orderBy(*cols: Union["ColumnOrName", List["ColumnOrName_"]]) -> "WindowSpec" | 3| b| 1| +---+--------+----------+ """ + from py4j.java_gateway import JVMView + sc = get_active_spark_context() jspec = cast(JVMView, sc._jvm).org.apache.spark.sql.expressions.Window.orderBy( _to_java_cols(cols) @@ -263,6 +266,8 @@ def rowsBetween(start: int, end: int) -> "WindowSpec": +---+--------+---+ """ + from py4j.java_gateway import JVMView + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding if end >= Window._FOLLOWING_THRESHOLD: @@ -351,6 +356,8 @@ def rangeBetween(start: int, end: int) -> "WindowSpec": +---+--------+---+ """ + from py4j.java_gateway import JVMView + if start <= Window._PRECEDING_THRESHOLD: start = Window.unboundedPreceding if end >= Window._FOLLOWING_THRESHOLD: @@ -375,7 +382,7 @@ class WindowSpec: Supports Spark Connect. """ - def __init__(self, jspec: JavaObject) -> None: + def __init__(self, jspec: "JavaObject") -> None: self._jspec = jspec @try_remote_windowspec diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py index b04fd155bb53..d0a24363c0c1 100644 --- a/python/pyspark/sql/worker/analyze_udtf.py +++ b/python/pyspark/sql/worker/analyze_udtf.py @@ -23,7 +23,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkRuntimeError, PySparkValueError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, @@ -34,7 +33,7 @@ from pyspark.sql.functions import OrderingColumn, PartitioningColumn, SelectedColumn from pyspark.sql.types import _parse_datatype_json_string, StructType from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, read_command, diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py index afba7d467854..530f18ef8288 100644 --- a/python/pyspark/sql/worker/commit_data_source_write.py +++ b/python/pyspark/sql/worker/commit_data_source_write.py @@ -20,7 +20,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, @@ -28,7 +27,7 @@ SpecialLengths, ) from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, pickleSer, diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py index 565b2cb187e2..1f11b65f44c7 100644 --- a/python/pyspark/sql/worker/create_data_source.py +++ b/python/pyspark/sql/worker/create_data_source.py @@ -21,7 +21,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, @@ -31,7 +30,7 @@ ) from 
pyspark.sql.datasource import DataSource, CaseInsensitiveDict from pyspark.sql.types import _parse_datatype_json_string, StructType -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, read_command, diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py index 91963658ee61..7f0127b71946 100644 --- a/python/pyspark/sql/worker/lookup_data_sources.py +++ b/python/pyspark/sql/worker/lookup_data_sources.py @@ -21,7 +21,6 @@ from typing import IO from pyspark.accumulators import _accumulatorRegistry -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_int, write_int, @@ -29,7 +28,7 @@ SpecialLengths, ) from pyspark.sql.datasource import DataSource -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, pickleSer, diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py index 3e5105996ed4..6c0d48caefeb 100644 --- a/python/pyspark/sql/worker/plan_data_source_read.py +++ b/python/pyspark/sql/worker/plan_data_source_read.py @@ -23,7 +23,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkRuntimeError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, @@ -39,7 +38,7 @@ BinaryType, StructType, ) -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, read_command, diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py index d4f81da5aceb..ba0a8037de60 100644 --- a/python/pyspark/sql/worker/python_streaming_sink_runner.py +++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py @@ -21,7 +21,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.errors import PySparkAssertionError, PySparkRuntimeError -from pyspark.java_gateway import local_connect_and_auth +from pyspark.util import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py index 490ede9ab0f2..ad8717cb33b5 100644 --- a/python/pyspark/sql/worker/write_into_data_source.py +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -22,7 +22,6 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.sql.connect.conversion import ArrowTableToRowsConversion from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError -from pyspark.java_gateway import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int, @@ -37,7 +36,7 @@ BinaryType, _create_row, ) -from pyspark.util import handle_worker_exception +from pyspark.util import handle_worker_exception, local_connect_and_auth from pyspark.worker_util import ( check_python_version, read_command, diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 84e7cd7fcc66..bb0a659a6b33 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -20,7 +20,7 @@ from pyspark import RDD, SparkConf 
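Each worker entry point above swaps from pyspark.java_gateway import local_connect_and_auth for the pyspark.util copy, so these modules no longer pull in the py4j gateway-launching code. A hedged sketch of such an entry point; the environment variable names follow the existing worker convention and, like run_worker itself, are assumptions rather than part of the diff:

    import os

    from pyspark.serializers import read_int, write_int
    from pyspark.util import handle_worker_exception, local_connect_and_auth


    def run_worker() -> None:
        """Hypothetical worker main in the style of the updated entry points."""
        java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
        auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
        # Connect back to the JVM over a local socket and run the auth handshake.
        sock_file, _ = local_connect_and_auth(java_port, auth_secret)
        try:
            length = read_int(sock_file)       # read one framed request
            _payload = sock_file.read(length)  # ... process it here ...
            write_int(0, sock_file)            # illustrative acknowledgement
            sock_file.flush()
        except BaseException as e:
            handle_worker_exception(e, sock_file)
            sock_file.flush()


    if __name__ == "__main__":
        run_worker()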
from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer -from pyspark.context import SparkContext +from pyspark.core.context import SparkContext from pyspark.storagelevel import StorageLevel from pyspark.streaming.dstream import DStream from pyspark.streaming.listener import StreamingListener diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index e8b3e4dd455d..145d1fff0e39 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -40,7 +40,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.streaming.util import rddToFileName, TransformFunction -from pyspark.rdd import portable_hash, RDD +from pyspark.core.rdd import portable_hash, RDD from pyspark.resultiterable import ResultIterable if TYPE_CHECKING: diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 329820c3ee01..829d6f628e9f 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -14,13 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import ClassVar, Type, Dict, List, Optional, Union, cast +from typing import ClassVar, Type, Dict, List, Optional, Union, cast, TYPE_CHECKING -from pyspark.java_gateway import local_connect_and_auth -from pyspark.resource import ResourceInformation +from pyspark.util import local_connect_and_auth from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer from pyspark.errors import PySparkRuntimeError +if TYPE_CHECKING: + from pyspark.resource import ResourceInformation + class TaskContext: @@ -127,7 +129,7 @@ class TaskContext: _taskAttemptId: Optional[int] = None _localProperties: Optional[Dict[str, str]] = None _cpus: Optional[int] = None - _resources: Optional[Dict[str, ResourceInformation]] = None + _resources: Optional[Dict[str, "ResourceInformation"]] = None def __new__(cls: Type["TaskContext"]) -> "TaskContext": """ @@ -240,7 +242,7 @@ def cpus(self) -> int: """ return cast(int, self._cpus) - def resources(self) -> Dict[str, ResourceInformation]: + def resources(self) -> Dict[str, "ResourceInformation"]: """ Resources allocated to the task. The key is the resource name and the value is information about the resource. @@ -250,7 +252,9 @@ def resources(self) -> Dict[str, ResourceInformation]: dict a dictionary of a string resource name, and :class:`ResourceInformation`. 
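taskcontext.py keeps ResourceInformation as a type-checking-only import and loads pyspark.resource inside resources() on first use, so importing TaskContext stays lightweight. Task-side usage is unchanged; a brief example, assuming a classic local SparkContext (partition_report is illustrative only):

    from pyspark.core.context import SparkContext  # classic core now lives under pyspark.core
    from pyspark.taskcontext import TaskContext


    def partition_report(_):
        ctx = TaskContext.get()
        # resources() imports pyspark.resource lazily the first time it is called.
        res = ctx.resources() if ctx is not None else {}
        yield {name: info.addresses for name, info in res.items()}


    if __name__ == "__main__":
        sc = SparkContext("local[2]", "resources-example")
        print(sc.parallelize(range(4), 2).mapPartitions(partition_report).collect())
        sc.stop()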
""" - return cast(Dict[str, ResourceInformation], self._resources) + from pyspark.resource import ResourceInformation + + return cast(Dict[str, "ResourceInformation"], self._resources) BARRIER_FUNCTION = 1 diff --git a/python/pyspark/tests/typing/test_rdd.yml b/python/pyspark/tests/typing/test_rdd.yml index 5207b1cd1ac1..b0b18db6a480 100644 --- a/python/pyspark/tests/typing/test_rdd.yml +++ b/python/pyspark/tests/typing/test_rdd.yml @@ -100,12 +100,12 @@ reveal_type(sc.parallelize([("a", 1)]).aggregateByKey(zero, seq_func, comb_func)) out: | - main:11: note: Revealed type is "pyspark.rdd.RDD[builtins.str]" - main:16: note: Revealed type is "pyspark.rdd.RDD[builtins.int]" - main:18: note: Revealed type is "pyspark.rdd.RDD[tuple[builtins.str, builtins.int]]" + main:11: note: Revealed type is "pyspark.core.rdd.RDD[builtins.str]" + main:16: note: Revealed type is "pyspark.core.rdd.RDD[builtins.int]" + main:18: note: Revealed type is "pyspark.core.rdd.RDD[tuple[builtins.str, builtins.int]]" main:20: note: Revealed type is "tuple[builtins.str, builtins.int]" main:22: note: Revealed type is "builtins.int" - main:34: note: Revealed type is "pyspark.rdd.RDD[tuple[builtins.str, builtins.set[builtins.str]]]" + main:34: note: Revealed type is "pyspark.core.rdd.RDD[tuple[builtins.str, builtins.set[builtins.str]]]" - case: rddMethodsErrors main: | diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py index af4169e7d89e..e7f1b373e2dd 100644 --- a/python/pyspark/traceback_utils.py +++ b/python/pyspark/traceback_utils.py @@ -51,7 +51,7 @@ class SCCallSiteSync: Helper for setting the spark context call site. Example usage: - from pyspark.context import SCCallSiteSync + from pyspark.core.context import SCCallSiteSync with SCCallSiteSync() as css: """ diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ec9c2489b41e..bf1cf5b59553 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -26,18 +26,48 @@ import threading import traceback import typing +import socket from types import TracebackType from typing import Any, Callable, IO, Iterator, List, Optional, TextIO, Tuple, Union -from py4j.clientserver import ClientServer - from pyspark.errors import PySparkRuntimeError +from pyspark.serializers import ( + write_int, + read_int, + write_with_length, + SpecialLengths, + UTF8Deserializer, +) __all__: List[str] = [] -from py4j.java_gateway import JavaObject - if typing.TYPE_CHECKING: + import io + + from py4j.java_collections import JavaArray + from py4j.java_gateway import JavaObject + + from pyspark._typing import NonUDFType + from pyspark.sql.pandas._typing import ( + PandasScalarUDFType, + PandasGroupedMapUDFType, + PandasGroupedAggUDFType, + PandasWindowAggUDFType, + PandasScalarIterUDFType, + PandasMapIterUDFType, + PandasCogroupedMapUDFType, + ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, + ArrowGroupedMapUDFType, + ArrowCogroupedMapUDFType, + ) + from pyspark.sql._typing import ( + SQLArrowBatchedUDFType, + SQLArrowTableUDFType, + SQLBatchedUDFType, + SQLTableUDFType, + ) + from pyspark.serializers import Serializer from pyspark.sql import SparkSession @@ -365,6 +395,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: # Non Spark Connect from pyspark import SparkContext + from py4j.clientserver import ClientServer if isinstance(SparkContext._gateway, ClientServer): # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. 
@@ -394,8 +425,6 @@ def handle_worker_exception(e: BaseException, outfile: IO) -> None: and exception traceback info to outfile. JVM could then read from the outfile and perform exception handling there. """ - from pyspark.serializers import write_int, write_with_length, SpecialLengths - try: exc_info = None if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False): @@ -437,7 +466,7 @@ class InheritableThread(threading.Thread): This API is experimental. """ - _props: JavaObject + _props: "JavaObject" def __init__( self, target: Callable, *args: Any, session: Optional["SparkSession"] = None, **kwargs: Any @@ -461,6 +490,7 @@ def copy_local_properties(*a: Any, **k: Any) -> Any: else: # Non Spark Connect from pyspark import SparkContext + from py4j.clientserver import ClientServer if isinstance(SparkContext._gateway, ClientServer): # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. @@ -491,6 +521,7 @@ def start(self) -> None: else: # Non Spark Connect from pyspark import SparkContext + from py4j.clientserver import ClientServer if isinstance(SparkContext._gateway, ClientServer): # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. @@ -503,11 +534,236 @@ def start(self) -> None: return super(InheritableThread, self).start() +class PythonEvalType: + """ + Evaluation type of python rdd. + + These values are internal to PySpark. + + These values should match values in org.apache.spark.api.python.PythonEvalType. + """ + + NON_UDF: "NonUDFType" = 0 + + SQL_BATCHED_UDF: "SQLBatchedUDFType" = 100 + SQL_ARROW_BATCHED_UDF: "SQLArrowBatchedUDFType" = 101 + + SQL_SCALAR_PANDAS_UDF: "PandasScalarUDFType" = 200 + SQL_GROUPED_MAP_PANDAS_UDF: "PandasGroupedMapUDFType" = 201 + SQL_GROUPED_AGG_PANDAS_UDF: "PandasGroupedAggUDFType" = 202 + SQL_WINDOW_AGG_PANDAS_UDF: "PandasWindowAggUDFType" = 203 + SQL_SCALAR_PANDAS_ITER_UDF: "PandasScalarIterUDFType" = 204 + SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 + SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 + SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 + SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209 + SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210 + + SQL_TABLE_UDF: "SQLTableUDFType" = 300 + SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301 + + +def _create_local_socket(sock_info: "JavaArray") -> "io.BufferedRWPair": + """ + Create a local socket that can be used to load deserialized data from the JVM + + Parameters + ---------- + sock_info : tuple + Tuple containing port number and authentication secret for a local socket. + + Returns + ------- + sockfile file descriptor of the local socket + """ + sockfile: "io.BufferedRWPair" + sock: "socket.socket" + port: int = sock_info[0] + auth_secret: str = sock_info[1] + sockfile, sock = local_connect_and_auth(port, auth_secret) + # The RDD materialization time is unpredictable, if we set a timeout for socket reading + # operation, it will very possibly fail. See SPARK-18281. + sock.settimeout(None) + return sockfile + + +def _load_from_socket(sock_info: "JavaArray", serializer: "Serializer") -> Iterator[Any]: + """ + Connect to a local socket described by sock_info and use the given serializer to yield data + + Parameters + ---------- + sock_info : tuple + Tuple containing port number and authentication secret for a local socket. 
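_create_local_socket and _load_from_socket now live in pyspark.util so both the classic RDD code and the SQL layer can stream JVM-served data without importing pyspark.rdd. The recipe is: authenticate against the local port, clear the read timeout (materialization time on the JVM side is unbounded, see SPARK-18281), then hand the socket file to a serializer. A hedged consumer sketch, assuming the JVM has already returned a (port, secret) pair; iter_results and the choice of CPickleSerializer are assumptions, not part of the diff:

    from typing import Any, Iterator, Tuple

    from pyspark.serializers import CPickleSerializer
    from pyspark.util import local_connect_and_auth


    def iter_results(sock_info: Tuple[int, str]) -> Iterator[Any]:
        """Hypothetical consumer mirroring _load_from_socket."""
        port, auth_secret = sock_info
        sockfile, sock = local_connect_and_auth(port, auth_secret)
        # No read timeout: the JVM may take arbitrarily long to produce data.
        sock.settimeout(None)
        return CPickleSerializer().load_stream(sockfile)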
+ serializer : class:`Serializer` + The PySpark serializer to use + + Returns + ------- + result of meth:`Serializer.load_stream`, + usually a generator that yields deserialized data + """ + sockfile = _create_local_socket(sock_info) + # The socket will be automatically closed when garbage-collected. + return serializer.load_stream(sockfile) + + +def _local_iterator_from_socket(sock_info: "JavaArray", serializer: "Serializer") -> Iterator[Any]: + class PyLocalIterable: + """Create a synchronous local iterable over a socket""" + + def __init__(self, _sock_info: "JavaArray", _serializer: "Serializer"): + port: int + auth_secret: str + jsocket_auth_server: "JavaObject" + port, auth_secret, self.jsocket_auth_server = _sock_info + self._sockfile = _create_local_socket((port, auth_secret)) + self._serializer = _serializer + self._read_iter: Iterator[Any] = iter([]) # Initialize as empty iterator + self._read_status = 1 + + def __iter__(self) -> Iterator[Any]: + while self._read_status == 1: + # Request next partition data from Java + write_int(1, self._sockfile) + self._sockfile.flush() + + # If response is 1 then there is a partition to read, if 0 then fully consumed + self._read_status = read_int(self._sockfile) + if self._read_status == 1: + # Load the partition data as a stream and read each item + self._read_iter = self._serializer.load_stream(self._sockfile) + for item in self._read_iter: + yield item + + # An error occurred, join serving thread and raise any exceptions from the JVM + elif self._read_status == -1: + self.jsocket_auth_server.getResult() + + def __del__(self) -> None: + # If local iterator is not fully consumed, + if self._read_status == 1: + try: + # Finish consuming partition data stream + for _ in self._read_iter: + pass + # Tell Java to stop sending data and close connection + write_int(0, self._sockfile) + self._sockfile.flush() + except Exception: + # Ignore any errors, socket is automatically closed when garbage-collected + pass + + return iter(PyLocalIterable(sock_info, serializer)) + + +def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) -> Tuple: + """ + Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection. + Handles IPV4 & IPV6, does some error handling. + + Parameters + ---------- + port : str or int, optional + auth_secret : str + + Returns + ------- + tuple + with (sockfile, sock) + """ + sock = None + errors = [] + # Support for both IPv4 and IPv6. + addr = "127.0.0.1" + if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true": + addr = "::1" + for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, _, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 15))) + sock.connect(sa) + sockfile = sock.makefile("rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) + except socket.error as e: + emsg = str(e) + errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg)) + if sock is not None: + sock.close() + sock = None + raise PySparkRuntimeError( + error_class="CANNOT_OPEN_SOCKET", + message_parameters={ + "errors": str(errors), + }, + ) + + +def _do_server_auth(conn: "io.IOBase", auth_secret: str) -> None: + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. 
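local_connect_and_auth plus _do_server_auth implement a small handshake: the client sends the shared secret length-prefixed and expects a length-prefixed "ok" back before any data flows. The real peer is the JVM's SocketAuthHelper; the toy in-process server below (toy_auth_server, SECRET) is purely an illustration of the wire format, not part of the diff:

    import socket
    import threading

    from pyspark.serializers import write_with_length
    from pyspark.util import local_connect_and_auth

    SECRET = "s3cret"


    def toy_auth_server(server: socket.socket) -> None:
        """Accept one connection and answer the handshake the way the JVM would."""
        conn, _ = server.accept()
        with conn, conn.makefile("rwb") as f:
            length = int.from_bytes(f.read(4), "big")  # length-prefixed secret
            received = f.read(length).decode("utf-8")
            write_with_length(b"ok" if received == SECRET else b"err", f)
            f.flush()


    if __name__ == "__main__":
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.bind(("127.0.0.1", 0))
        server.listen(1)
        threading.Thread(target=toy_auth_server, args=(server,), daemon=True).start()
        sockfile, sock = local_connect_and_auth(server.getsockname()[1], SECRET)
        print("authenticated")  # a non-"ok" reply raises PySparkRuntimeError instead
        sock.close()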
+ """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise PySparkRuntimeError( + error_class="UNEXPECTED_RESPONSE_FROM_SERVER", + message_parameters={}, + ) + + +_is_remote_only = None + + +def is_remote_only() -> bool: + """ + Returns if the current running environment is only for Spark Connect. + If users install pyspark-connect alone, RDD API does not exist. + + .. versionadded:: 4.0.0 + + Notes + ----- + This will only return ``True`` if installed PySpark is only for Spark Connect. + Otherwise, it returns ``False``. + + This API is unstable, and for developers. + + Returns + ------- + bool + + Examples + -------- + >>> from pyspark.sql import is_remote + >>> is_remote() + False + """ + global _is_remote_only + + if _is_remote_only is not None: + return _is_remote_only + try: + from pyspark import core # noqa: F401 + + _is_remote_only = False + return _is_remote_only + except ImportError: + _is_remote_only = True + return _is_remote_only + + if __name__ == "__main__": if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7): import doctest import pyspark.util - from pyspark.context import SparkContext + from pyspark.core.context import SparkContext globs = pyspark.util.__dict__.copy() globs["sc"] = SparkContext("local[4]", "PythonTest") diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7ce4c17edf54..41f6c35bc445 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -32,10 +32,9 @@ _accumulatorRegistry, _deserialize_accumulator, ) -from pyspark.java_gateway import local_connect_and_auth from pyspark.taskcontext import BarrierTaskContext, TaskContext from pyspark.resource import ResourceInformation -from pyspark.rdd import PythonEvalType +from pyspark.util import PythonEvalType, local_connect_and_auth from pyspark.serializers import ( write_int, read_long, diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py index 9ee18758ce42..f3c59c91ea2c 100644 --- a/python/pyspark/worker_util.py +++ b/python/pyspark/worker_util.py @@ -33,10 +33,10 @@ has_resource_module = False from pyspark.accumulators import _accumulatorRegistry -from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.core.broadcast import Broadcast, _broadcastRegistry from pyspark.errors import PySparkRuntimeError -from pyspark.files import SparkFiles -from pyspark.java_gateway import local_connect_and_auth +from pyspark.core.files import SparkFiles +from pyspark.util import local_connect_and_auth from pyspark.serializers import ( read_bool, read_int,