diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 1f67f4c49de97..f47035df490d3 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -32,13 +32,15 @@ if have_pandas: from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.sql.connect.client import ChannelBuilder + from pyspark.sql.connect.dataframe import DataFrame as CDataFrame from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit, col + from pyspark.testing.pandasutils import PandasOnSparkTestCase +else: + from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore from pyspark.sql.dataframe import DataFrame import pyspark.sql.functions -from pyspark.sql.connect.dataframe import DataFrame as CDataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message -from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.utils import ReusedPySparkTestCase @@ -881,6 +883,7 @@ def test_crossjoin(self): ) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class ChannelBuilderTests(ReusedPySparkTestCase): def test_invalid_connection_strings(self): invalid = [ diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index 6d06421d0848f..ac0718fd6308b 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -26,7 +26,7 @@ from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.testing.pandasutils import PandasOnSparkTestCase else: - from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase + from pyspark.testing.sqlutils import ReusedSQLTestCase as PandasOnSparkTestCase # type: ignore from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 05df6b02e6726..7f4250613cc20 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -97,6 +97,7 @@ def _session_sql(cls, query: str) -> "DataFrame": return DataFrame.withPlan(SQL(query), cls.connect) # type: ignore if have_pandas: + @classmethod def _with_plan(cls, plan: LogicalPlan) -> "DataFrame": return DataFrame.withPlan(plan, cls.connect) # type: ignore