diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 459b05cc37aa..17d50f0f50e9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -18,14 +18,18 @@ import unittest import shutil import tempfile +from pyspark.testing.sqlutils import have_pandas -import pandas +if have_pandas: + import pandas from pyspark.sql import SparkSession, Row from pyspark.sql.types import StructType, StructField, LongType, StringType -from pyspark.sql.connect.client import RemoteSparkSession -from pyspark.sql.connect.function_builder import udf -from pyspark.sql.connect.functions import lit + +if have_pandas: + from pyspark.sql.connect.client import RemoteSparkSession + from pyspark.sql.connect.function_builder import udf + from pyspark.sql.connect.functions import lit from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase @@ -36,7 +40,8 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): """Parent test fixture class for all Spark Connect related test cases.""" - connect: RemoteSparkSession + if have_pandas: + connect: RemoteSparkSession tbl_name: str df_text: "DataFrame" diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py index 6036b63d76f2..790a987e8809 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py +++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py @@ -15,14 +15,20 @@ # limitations under the License. # +from typing import cast +import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect.proto import Expression as ProtoExpression -import pyspark.sql.connect as c -import pyspark.sql.connect.plan as p -import pyspark.sql.connect.column as col -import pyspark.sql.connect.functions as fun +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + from pyspark.sql.connect.proto import Expression as ProtoExpression + import pyspark.sql.connect as c + import pyspark.sql.connect.plan as p + import pyspark.sql.connect.column as col + import pyspark.sql.connect.functions as fun + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture): def test_simple_column_expressions(self): df = c.DataFrame.withPlan(p.Read("table")) diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index 450f5c70faba..14b939e019ba 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -14,15 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import cast import unittest from pyspark.testing.connectutils import PlanOnlyTestFixture -import pyspark.sql.connect.proto as proto -from pyspark.sql.connect.readwriter import DataFrameReader -from pyspark.sql.connect.function_builder import UserDefinedFunction, udf -from pyspark.sql.types import StringType +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +if have_pandas: + import pyspark.sql.connect.proto as proto + from pyspark.sql.connect.readwriter import DataFrameReader + from pyspark.sql.connect.function_builder import UserDefinedFunction, udf + from pyspark.sql.types import StringType + +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectTestsPlanOnly(PlanOnlyTestFixture): """These test cases exercise the interface to the proto plan generation but do not call Spark.""" diff --git a/python/pyspark/sql/tests/connect/test_connect_select_ops.py b/python/pyspark/sql/tests/connect/test_connect_select_ops.py index e89b4b34ea01..a29c70541462 100644 --- a/python/pyspark/sql/tests/connect/test_connect_select_ops.py +++ b/python/pyspark/sql/tests/connect/test_connect_select_ops.py @@ -14,13 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from typing import cast +import unittest + from pyspark.testing.connectutils import PlanOnlyTestFixture -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.functions import col -from pyspark.sql.connect.plan import Read -import pyspark.sql.connect.proto as proto +from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message + +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.functions import col + from pyspark.sql.connect.plan import Read + import pyspark.sql.connect.proto as proto +@unittest.skipIf(not have_pandas, cast(str, pandas_requirement_message)) class SparkConnectToProtoSuite(PlanOnlyTestFixture): def test_select_with_columns_and_strings(self): df = DataFrame.withPlan(Read("table")) diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 700b7bb72e18..d9bced3af114 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -18,13 +18,18 @@ from typing import Any, Dict import functools import unittest +from pyspark.testing.sqlutils import have_pandas -from pyspark.sql.connect import DataFrame -from pyspark.sql.connect.plan import Read -from pyspark.testing.utils import search_jar +if have_pandas: + from pyspark.sql.connect import DataFrame + from pyspark.sql.connect.plan import Read + from pyspark.testing.utils import search_jar + + connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") +else: + connect_jar = None -connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect") if connect_jar is None: connect_requirement_message = ( "Skipping all Spark Connect Python tests as the optional Spark Connect project was " @@ -38,7 +43,7 @@ os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, plugin_args, existing_args]) connect_requirement_message = None # type: ignore -should_test_connect = connect_requirement_message is None +should_test_connect = connect_requirement_message is None and have_pandas class MockRemoteSession: