diff --git a/.github/workflows/pytest-pyspark.yml b/.github/workflows/pytest-pyspark.yml
index af226297c2..659cb40619 100644
--- a/.github/workflows/pytest-pyspark.yml
+++ b/.github/workflows/pytest-pyspark.yml
@@ -35,3 +35,75 @@ jobs:
         run: uv pip freeze
       - name: Run pytest
         run: pytest tests --cov=narwhals/_spark_like --cov-fail-under=95 --runslow --constructors pyspark
+
+
+  pytest-pyspark-connect-constructor:
+    if: ${{ contains(github.event.pull_request.labels.*.name, 'pyspark-connect') || contains(github.event.pull_request.labels.*.name, 'release') }}
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11"]
+        os: [ubuntu-latest]
+    env:
+      SPARK_VERSION: 3.5.5
+      SPARK_PORT: 15002
+      SPARK_CONNECT: true
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          enable-cache: "true"
+          cache-suffix: ${{ matrix.python-version }}
+          cache-dependency-glob: "pyproject.toml"
+
+      - name: Install Java 17
+        uses: actions/setup-java@v4
+        with:
+          distribution: zulu
+          java-version: 17
+
+      - name: install-reqs
+        run: uv pip install -e . --group core-tests --group extra --system
+      - name: install pyspark
+        run: echo "setuptools<78" | uv pip install -e . "pyspark[connect]==${SPARK_VERSION}" --system
+      - name: show-deps
+        run: uv pip freeze
+
+      - name: Cache Spark
+        id: cache-spark
+        uses: actions/cache@v4
+        with:
+          path: /opt/spark
+          key: spark-${{ env.SPARK_VERSION }}-bin-hadoop3
+
+      - name: Download Spark
+        if: steps.cache-spark.outputs.cache-hit != 'true'
+        run: |
+          wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz
+          tar -xzf spark-${SPARK_VERSION}-bin-hadoop3.tgz
+          sudo mv spark-${SPARK_VERSION}-bin-hadoop3 /opt/spark
+
+      - name: Set Spark env variables
+        run: |
+          echo "SPARK_HOME=/opt/spark" >> $GITHUB_ENV
+          echo "/opt/spark/bin" >> $GITHUB_PATH
+
+      - name: Start Spark Connect server
+        run: |
+          $SPARK_HOME/sbin/start-connect-server.sh \
+            --packages org.apache.spark:spark-connect_2.12:${SPARK_VERSION} \
+            --conf spark.connect.grpc.binding.port=${SPARK_PORT}
+          sleep 5
+          echo "Spark Connect server started"
+
+      - name: Run pytest
+        run: pytest tests --cov=narwhals/_spark_like --cov-fail-under=95 --runslow --constructors "pyspark[connect]"
+
+      - name: Stop Spark Connect server
+        if: always()
+        run: $SPARK_HOME/sbin/stop-connect-server.sh
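The job above provisions a standalone Spark distribution, starts a Spark Connect server on SPARK_PORT, and points the test suite at it over gRPC instead of an in-process JVM session. For reference, a minimal client-side sketch of the connection this enables (an assumption-laden example: it presumes pyspark[connect]>=3.5 is installed and a server is listening on localhost:15002; the data is illustrative):

# sketch: connect to the Spark Connect server started by the CI job
import os

from pyspark.sql import SparkSession

spark = SparkSession.builder.remote(
    f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}"
).getOrCreate()
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "letter"])
print(df.toPandas())  # the plan executes on the server; results return over gRPC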
diff --git a/narwhals/_namespace.py b/narwhals/_namespace.py
index 0e3931bfc9..d1a5d20e87 100644
--- a/narwhals/_namespace.py
+++ b/narwhals/_namespace.py
@@ -20,6 +20,7 @@
 from narwhals.dependencies import get_pyarrow
 from narwhals.dependencies import is_dask_dataframe
 from narwhals.dependencies import is_duckdb_relation
+from narwhals.dependencies import is_pyspark_connect_dataframe
 from narwhals.dependencies import is_pyspark_dataframe
 from narwhals.dependencies import is_sqlframe_dataframe
 from narwhals.utils import Implementation
@@ -34,6 +35,7 @@
     import polars as pl
     import pyarrow as pa
     import pyspark.sql as pyspark_sql
+    from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
     from typing_extensions import TypeAlias
     from typing_extensions import TypeIs
@@ -72,7 +74,10 @@
         _PandasLike, Implementation.PANDAS, Implementation.CUDF, Implementation.MODIN
     ]
     SparkLike: TypeAlias = Literal[
-        _SparkLike, Implementation.PYSPARK, Implementation.SQLFRAME
+        _SparkLike,
+        Implementation.PYSPARK,
+        Implementation.SQLFRAME,
+        Implementation.PYSPARK_CONNECT,
     ]
     EagerOnly: TypeAlias = "PandasLike | Arrow"
     EagerAllowed: TypeAlias = "EagerOnly | Polars"
@@ -111,7 +116,10 @@ class _ModinSeries(Protocol):
     _NativePandasLike: TypeAlias = "_NativePandas | _NativeCuDF | _NativeModin"
     _NativeSQLFrame: TypeAlias = "SQLFrameDataFrame"
     _NativePySpark: TypeAlias = "pyspark_sql.DataFrame"
-    _NativeSparkLike: TypeAlias = "_NativeSQLFrame | _NativePySpark"
+    _NativePySparkConnect: TypeAlias = "PySparkConnectDataFrame"
+    _NativeSparkLike: TypeAlias = (
+        "_NativeSQLFrame | _NativePySpark | _NativePySparkConnect"
+    )
     NativeKnown: TypeAlias = "_NativePolars | _NativeArrow | _NativePandasLike | _NativeSparkLike | _NativeDuckDB | _NativeDask"
     NativeUnknown: TypeAlias = (
@@ -292,6 +300,8 @@ def from_native_object(
             return cls.from_backend(
                 Implementation.SQLFRAME
                 if is_native_sqlframe(native)
+                else Implementation.PYSPARK_CONNECT
+                if is_native_pyspark_connect(native)
                 else Implementation.PYSPARK
             )
         elif is_native_dask(native):
@@ -326,6 +336,7 @@ def is_native_dask(obj: Any) -> TypeIs[_NativeDask]:
 is_native_duckdb: _Guard[_NativeDuckDB] = is_duckdb_relation
 is_native_sqlframe: _Guard[_NativeSQLFrame] = is_sqlframe_dataframe
 is_native_pyspark: _Guard[_NativePySpark] = is_pyspark_dataframe
+is_native_pyspark_connect: _Guard[_NativePySparkConnect] = is_pyspark_connect_dataframe


 def is_native_pandas(obj: Any) -> TypeIs[_NativePandas]:
@@ -351,4 +362,8 @@ def is_native_pandas_like(obj: Any) -> TypeIs[_NativePandasLike]:


 def is_native_spark_like(obj: Any) -> TypeIs[_NativeSparkLike]:
-    return is_native_pyspark(obj) or is_native_sqlframe(obj)
+    return (
+        is_native_sqlframe(obj)
+        or is_native_pyspark(obj)
+        or is_native_pyspark_connect(obj)
+    )
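With is_native_pyspark_connect wired into from_native_object above, a Spark Connect DataFrame now resolves to the spark-like backend instead of falling through as unknown. A hedged usage sketch through narwhals' public entry point (assumes a reachable Connect server, such as the one the CI job starts):

# sketch: from_native should route a Connect DataFrame to the spark-like backend
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
native = spark.createDataFrame([(1,), (2,)], ["a"])
lf = nw.from_native(native)
print(lf.implementation)  # expected after this change: Implementation.PYSPARK_CONNECT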
diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py
index 9cde2f3b05..9493cb8dfc 100644
--- a/narwhals/_spark_like/dataframe.py
+++ b/narwhals/_spark_like/dataframe.py
@@ -141,10 +141,35 @@ def _with_native(self, df: SQLFrameDataFrame) -> Self:
             implementation=self._implementation,
         )

+    def _to_arrow_schema(self) -> pa.Schema:  # pragma: no cover
+        import pyarrow as pa  # ignore-banned-import
+
+        from narwhals._arrow.utils import narwhals_to_native_dtype
+
+        schema: list[tuple[str, pa.DataType]] = []
+        nw_schema = self.collect_schema()
+        native_schema = self.native.schema
+        for key, value in nw_schema.items():
+            try:
+                native_dtype = narwhals_to_native_dtype(value, self._version)
+            except Exception as exc:  # noqa: BLE001,PERF203
+                native_spark_dtype = native_schema[key].dataType  # type: ignore[index]
+                # If we can't convert the type, just set it to `pa.null`, and warn.
+                # Avoid the warning if we're starting from PySpark's void type.
+                # We can avoid the check when we introduce `nw.Null` dtype.
+                null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
+                if not isinstance(native_spark_dtype, null_type):
+                    warnings.warn(
+                        f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
+                        stacklevel=find_stacklevel(),
+                    )
+                schema.append((key, pa.null()))
+            else:
+                schema.append((key, native_dtype))
+        return pa.schema(schema)
+
     def _collect_to_arrow(self) -> pa.Table:
-        if self._implementation is Implementation.PYSPARK and self._backend_version < (
-            4,
-        ):
+        if self._implementation.is_pyspark() and self._backend_version < (4,):
             import pyarrow as pa  # ignore-banned-import

             try:
@@ -152,32 +177,17 @@ def _collect_to_arrow(self) -> pa.Table:
             except ValueError as exc:
                 if "at least one RecordBatch" in str(exc):
                     # Empty dataframe
-                    from narwhals._arrow.utils import narwhals_to_native_dtype
-
-                    data: dict[str, list[Any]] = {}
-                    schema: list[tuple[str, pa.DataType]] = []
-                    current_schema = self.collect_schema()
-                    for key, value in current_schema.items():
-                        data[key] = []
-                        try:
-                            native_dtype = narwhals_to_native_dtype(value, self._version)
-                        except Exception as exc:  # noqa: BLE001
-                            native_spark_dtype = self.native.schema[key].dataType  # type: ignore[index]
-                            # If we can't convert the type, just set it to `pa.null`, and warn.
-                            # Avoid the warning if we're starting from PySpark's void type.
-                            # We can avoid the check when we introduce `nw.Null` dtype.
-                            null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
-                            if not isinstance(native_spark_dtype, null_type):
-                                warnings.warn(
-                                    f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
-                                    stacklevel=find_stacklevel(),
-                                )
-                            schema.append((key, pa.null()))
-                        else:
-                            schema.append((key, native_dtype))
-                    return pa.Table.from_pydict(data, schema=pa.schema(schema))
+
+                    data: dict[str, list[Any]] = {k: [] for k in self.columns}
+                    pa_schema = self._to_arrow_schema()
+                    return pa.Table.from_pydict(data, schema=pa_schema)
                 else:  # pragma: no cover
                     raise
+        elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
+            import pyarrow as pa  # ignore-banned-import
+
+            pa_schema = self._to_arrow_schema()
+            return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
         else:
             return self.native.toArrow()
@@ -293,7 +303,7 @@ def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
         return self._with_native(self.native.drop(*columns_to_drop))

     def head(self, n: int) -> Self:
-        return self._with_native(self.native.limit(num=n))
+        return self._with_native(self.native.limit(n))

     def group_by(
         self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
@@ -445,7 +455,7 @@ def explode(self, columns: Sequence[str]) -> Self:
             )
             raise NotImplementedError(msg)

-        if self._implementation.is_pyspark():
+        if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
             return self._with_native(
                 self.native.select(
                     *[
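The new Spark Connect branch of _collect_to_arrow (for backend versions below 4.0, which lack DataFrame.toArrow()) round-trips through pandas but supplies the precomputed Arrow schema, so column types stay stable even when the frame is empty. A standalone sketch of that pattern with a toy schema, separate from the narwhals code itself:

# sketch: an explicit schema overrides pandas' inference on an empty frame
import pandas as pd
import pyarrow as pa

pdf = pd.DataFrame({"a": [], "b": []})  # empty: pandas infers object dtype
schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
table = pa.Table.from_pandas(pdf, schema=schema, preserve_index=False)
print(table.schema)  # a: int64, b: string - taken from the schema, not inferred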
diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index 088468b02a..844861e2a8 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -75,7 +75,8 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
         def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
             return [
-                result.over(df._Window().partitionBy(df._F.lit(1))) for result in self(df)
+                result.over(self._Window().partitionBy(self._F.lit(1)))
+                for result in self(df)
             ]

         return self.__class__(
@@ -438,7 +439,8 @@ def mean(self) -> Self:
     def median(self) -> Self:
         def _median(_input: Column) -> Column:
             if (
-                self._implementation.is_pyspark()
+                self._implementation
+                in {Implementation.PYSPARK, Implementation.PYSPARK_CONNECT}
                 and (pyspark := get_pyspark()) is not None
                 and parse_version(pyspark) < (3, 4)
             ):  # pragma: no cover
@@ -772,7 +774,7 @@ def _rank(_input: Column) -> Column:
             else:
                 order_by_cols = [self._F.asc_nulls_last(_input)]

-            window = self._Window().orderBy(order_by_cols)
+            window = self._Window().partitionBy(self._F.lit(1)).orderBy(order_by_cols)
             count_window = self._Window().partitionBy(_input)

             if method == "max":
diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index bccb199aca..6c6ceacfca 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -244,6 +244,10 @@ def import_functions(implementation: Implementation, /) -> ModuleType:
     if implementation is Implementation.PYSPARK:
         from pyspark.sql import functions

+        return functions
+    if implementation is Implementation.PYSPARK_CONNECT:
+        from pyspark.sql.connect import functions
+
         return functions
     from sqlframe.base.session import _BaseSession
@@ -254,6 +258,10 @@ def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
     if implementation is Implementation.PYSPARK:
         from pyspark.sql import types

+        return types
+    if implementation is Implementation.PYSPARK_CONNECT:
+        from pyspark.sql.connect import types
+
         return types
     from sqlframe.base.session import _BaseSession
@@ -264,6 +272,11 @@ def import_window(implementation: Implementation, /) -> type[Any]:
     if implementation is Implementation.PYSPARK:
         from pyspark.sql import Window

+        return Window
+
+    if implementation is Implementation.PYSPARK_CONNECT:
+        from pyspark.sql.connect.window import Window
+
         return Window
     from sqlframe.base.session import _BaseSession
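In _rank above, the ordering window now carries an explicit partition spec (partitionBy(lit(1))), matching what broadcast already does; an ORDER BY-only window pulls every row into a single partition, and the Connect backend appears to be stricter about requiring the spec. A self-contained illustration of the pattern against a plain local session (illustrative data):

# sketch: rank over a single explicit partition
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(3,), (1,), (2,)], ["a"])
w = Window.partitionBy(F.lit(1)).orderBy(F.asc_nulls_last("a"))
df.select("a", F.rank().over(w).alias("rk")).show()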
diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py
index fb8154685b..6b602995f2 100644
--- a/narwhals/dependencies.py
+++ b/narwhals/dependencies.py
@@ -17,6 +17,7 @@
     import polars as pl
     import pyarrow as pa
     import pyspark.sql as pyspark_sql
+    from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
     from typing_extensions import TypeGuard
     from typing_extensions import TypeIs
@@ -112,6 +113,11 @@ def get_pyspark_sql() -> Any:
     return sys.modules.get("pyspark.sql", None)


+def get_pyspark_connect() -> Any:
+    """Get pyspark.sql.connect module (if already imported - else return None)."""
+    return sys.modules.get("pyspark.sql.connect", None)
+
+
 def get_sqlframe() -> Any:
     """Get sqlframe module (if already imported - else return None)."""
     return sys.modules.get("sqlframe", None)
@@ -230,6 +236,14 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
     )


+def is_pyspark_connect_dataframe(df: Any) -> TypeIs[PySparkConnectDataFrame]:
+    """Check whether `df` is a PySpark Connect DataFrame without importing PySpark."""
+    return bool(
+        (pyspark_connect := get_pyspark_connect()) is not None
+        and isinstance(df, pyspark_connect.dataframe.DataFrame)
+    )
+
+
 def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]:
     """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
     if get_sqlframe() is not None:
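The dedicated guard exists because, on PySpark 3.x, the Connect DataFrame is an unrelated class rather than a subclass of the classic pyspark.sql.DataFrame, so is_pyspark_dataframe never matches it; the two hierarchies were only unified in Spark 4. A quick check that runs without any server:

# sketch: the classic and Connect DataFrame classes are unrelated on PySpark 3.5
import pyspark.sql
import pyspark.sql.connect.dataframe

print(issubclass(pyspark.sql.connect.dataframe.DataFrame, pyspark.sql.DataFrame))  # False on 3.5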
diff --git a/narwhals/utils.py b/narwhals/utils.py
index eb4ef8328e..715670cabe 100644
--- a/narwhals/utils.py
+++ b/narwhals/utils.py
@@ -35,6 +35,7 @@
 from narwhals.dependencies import get_polars
 from narwhals.dependencies import get_pyarrow
 from narwhals.dependencies import get_pyspark
+from narwhals.dependencies import get_pyspark_connect
 from narwhals.dependencies import get_pyspark_sql
 from narwhals.dependencies import get_sqlframe
 from narwhals.dependencies import is_cudf_series
@@ -233,6 +234,8 @@ class Implementation(Enum):
     """Ibis implementation."""
     SQLFRAME = auto()
     """SQLFrame implementation."""
+    PYSPARK_CONNECT = auto()
+    """PySpark Connect implementation."""
     UNKNOWN = auto()
     """Unknown implementation."""
@@ -260,6 +263,7 @@ def from_native_namespace(
             get_duckdb(): Implementation.DUCKDB,
             get_ibis(): Implementation.IBIS,
             get_sqlframe(): Implementation.SQLFRAME,
+            get_pyspark_connect(): Implementation.PYSPARK_CONNECT,
         }
         return mapping.get(native_namespace, Implementation.UNKNOWN)
@@ -286,6 +290,7 @@ def from_string(
             "duckdb": Implementation.DUCKDB,
             "ibis": Implementation.IBIS,
             "sqlframe": Implementation.SQLFRAME,
+            "pyspark_connect": Implementation.PYSPARK_CONNECT,
         }
         return mapping.get(backend_name, Implementation.UNKNOWN)
@@ -354,6 +359,11 @@ def to_native_namespace(self) -> ModuleType:
             return sqlframe

+        if self is Implementation.PYSPARK_CONNECT:  # pragma: no cover
+            import pyspark.sql  # ignore-banned-import
+
+            return pyspark.sql
+
         msg = "Not supported Implementation"  # pragma: no cover
         raise AssertionError(msg)
@@ -407,7 +417,11 @@ def is_spark_like(self) -> bool:
             >>> df.implementation.is_spark_like()
             False
         """
-        return self in {Implementation.PYSPARK, Implementation.SQLFRAME}
+        return self in {
+            Implementation.PYSPARK,
+            Implementation.SQLFRAME,
+            Implementation.PYSPARK_CONNECT,
+        }

     def is_polars(self) -> bool:
         """Return whether implementation is Polars.
@@ -473,6 +487,22 @@ def is_pyspark(self) -> bool:
         """
         return self is Implementation.PYSPARK  # pragma: no cover

+    def is_pyspark_connect(self) -> bool:
+        """Return whether implementation is PySpark Connect.
+
+        Returns:
+            Boolean.
+
+        Examples:
+            >>> import polars as pl
+            >>> import narwhals as nw
+            >>> df_native = pl.DataFrame({"a": [1, 2, 3]})
+            >>> df = nw.from_native(df_native)
+            >>> df.implementation.is_pyspark_connect()
+            False
+        """
+        return self is Implementation.PYSPARK_CONNECT  # pragma: no cover
+
     def is_pyarrow(self) -> bool:
         """Return whether implementation is PyArrow.
@@ -571,6 +601,7 @@ def _alias(self) -> LiteralString:
             Implementation.PYSPARK: "PySpark",
             Implementation.DUCKDB: "DuckDB",
             Implementation.SQLFRAME: "SQLFrame",
+            Implementation.PYSPARK_CONNECT: "PySpark Connect",
         }
         return mapping[self]
@@ -579,11 +610,12 @@ def _backend_version(self) -> tuple[int, ...]:
         into_version: Any
         if self not in {
             Implementation.PYSPARK,
+            Implementation.PYSPARK_CONNECT,
             Implementation.DASK,
             Implementation.SQLFRAME,
         }:
             into_version = native
-        elif self is Implementation.PYSPARK:
+        elif self in {Implementation.PYSPARK, Implementation.PYSPARK_CONNECT}:
             into_version = get_pyspark()  # pragma: no cover
         elif self is Implementation.DASK:
             into_version = get_dask()
@@ -600,6 +632,7 @@ def _backend_version(self) -> tuple[int, ...]:
         Implementation.CUDF: (24, 10),
         Implementation.PYARROW: (11,),
         Implementation.PYSPARK: (3, 5),
+        Implementation.PYSPARK_CONNECT: (3, 5),
         Implementation.POLARS: (0, 20, 3),
         Implementation.DASK: (2024, 8),
         Implementation.DUCKDB: (1,),
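A quick smoke check of the enum wiring above, using only names this diff introduces:

# sketch: string lookup and predicate helpers for the new member
from narwhals.utils import Implementation

impl = Implementation.from_string("pyspark_connect")
assert impl is Implementation.PYSPARK_CONNECT
assert impl.is_spark_like()
assert impl.is_pyspark_connect() and not impl.is_pyspark()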
diff --git a/pyproject.toml b/pyproject.toml
index 6d3332cecf..d6c4dec71e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,7 @@ modin = ["modin"]
 cudf = ["cudf>=24.10.0"]
 pyarrow = ["pyarrow>=11.0.0"]
 pyspark = ["pyspark>=3.5.0"]
+pyspark-connect = ["pyspark[connect]>=3.5.0"]
 polars = ["polars>=0.20.3"]
 dask = ["dask[dataframe]>=2024.8"]
 duckdb = ["duckdb>=1.0"]
@@ -257,14 +258,17 @@ exclude_also = [
     "if .*implementation is Implementation.CUDF",
     "if .*implementation is Implementation.MODIN",
     "if .*implementation is Implementation.PYSPARK",
+    "if .*implementation is Implementation.PYSPARK_CONNECT",
     "if .*implementation.is_cudf",
     "if .*implementation.is_modin",
     "if .*implementation.is_pyspark",
+    "if .*implementation.is_pyspark_connect",
     'request.applymarker\(pytest.mark.xfail',
     'backend_version <',
     'if "cudf" in str\(constructor',
     'if "modin" in str\(constructor',
     'if "pyspark" in str\(constructor',
+    'if "pyspark_connect" in str\(constructor',
     'pytest.skip\('
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
index 3575d986de..a7d1dc850f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -163,7 +163,14 @@ def pyspark_lazy_constructor() -> Callable[[Data], PySparkDataFrame]:  # pragma: no cover
     import warnings
     from atexit import register

-    from pyspark.sql import SparkSession
+    is_spark_connect = bool(os.environ.get("SPARK_CONNECT", None))
+
+    if TYPE_CHECKING:
+        from pyspark.sql import SparkSession
+    elif is_spark_connect:
+        from pyspark.sql.connect.session import SparkSession
+    else:
+        from pyspark.sql import SparkSession

     with warnings.catch_warnings():
         # The spark session seems to trigger a polars warning.
@@ -171,12 +178,14 @@ def pyspark_lazy_constructor() -> Callable[[Data], PySparkDataFrame]:  # pragma: no cover
         warnings.filterwarnings(
             "ignore", r"Using fork\(\) can cause Polars", category=RuntimeWarning
         )
-        builder = cast("SparkSession.Builder", SparkSession.builder)
+        builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
+
         session = (
-            builder.appName("unit-tests")
-            .master("local[1]")
-            .config("spark.ui.enabled", "false")
-            # executing one task at a time makes the tests faster
+            (
+                builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
+                if is_spark_connect
+                else builder.master("local[1]").config("spark.ui.enabled", "false")
+            )
             .config("spark.default.parallelism", "1")
             .config("spark.sql.shuffle.partitions", "2")
             # common timezone for all tests environments
@@ -237,7 +246,12 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
         selected_constructors = [
             x
             for x in selected_constructors
-            if x not in GPU_CONSTRUCTORS and x != "modin"  # too slow
+            if x not in GPU_CONSTRUCTORS
+            and x
+            not in {
+                "modin",  # too slow
+                "pyspark[connect]",  # complex local setup; can't run together with local spark
+            }
         ]
     else:  # pragma: no cover
         opt = cast("str", metafunc.config.getoption("constructors"))
@@ -259,7 +273,7 @@ def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
             eager_constructors.append(EAGER_CONSTRUCTORS[constructor])
             eager_constructors_ids.append(constructor)
             constructors.append(EAGER_CONSTRUCTORS[constructor])
-        elif constructor == "pyspark":  # pragma: no cover
+        elif constructor in {"pyspark", "pyspark[connect]"}:  # pragma: no cover
             constructors.append(pyspark_lazy_constructor())
         elif constructor in LAZY_CONSTRUCTORS:
             constructors.append(LAZY_CONSTRUCTORS[constructor])
diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py
index 3335ce655e..010fd7bfe3 100644
--- a/tests/frame/select_test.py
+++ b/tests/frame/select_test.py
@@ -165,7 +165,7 @@ def test_select_duplicates(constructor: Constructor) -> None:
         # cudf already raises its own error
         pytest.skip()
     df = nw.from_native(constructor({"a": [1, 2]})).lazy()
-    with pytest.raises(ValueError, match="Expected unique|duplicate|more than one"):
+    with pytest.raises(ValueError, match="Expected unique|[Dd]uplicate|more than one"):
         df.select("a", nw.col("a") + 1).collect()
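The conftest fixture above (and the session setup repeated in the read/scan tests below) switches between a local JVM session and a Connect session based on the SPARK_CONNECT and SPARK_PORT variables exported by the new CI job. A trimmed-down sketch of the same builder logic; the real fixture also imports the Connect SparkSession class directly, pins parallelism, and registers cleanup at exit:

# sketch: env-driven choice between local and Connect sessions
import os

from pyspark.sql import SparkSession

if os.environ.get("SPARK_CONNECT"):
    builder = SparkSession.builder.remote(
        f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}"
    )
else:
    builder = SparkSession.builder.master("local[1]").config("spark.ui.enabled", "false")

session = builder.config("spark.sql.session.timeZone", "UTC").getOrCreate()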
diff --git a/tests/read_scan_test.py b/tests/read_scan_test.py
index 7a758ff78b..6d9a9af10c 100644
--- a/tests/read_scan_test.py
+++ b/tests/read_scan_test.py
@@ -1,9 +1,11 @@
 from __future__ import annotations

+import os
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Literal
 from typing import Mapping
+from typing import cast

 import pandas as pd
 import pytest
@@ -73,21 +75,27 @@ def test_scan_csv(
             "header": True,
         }
     elif "pyspark" in str(constructor):
-        from pyspark.sql import SparkSession
-
-        kwargs = {
-            "session": (
-                SparkSession.builder.appName("unit-tests")  # pyright: ignore[reportAttributeAccessIssue]
-                .master("local[1]")
-                .config("spark.ui.enabled", "false")
-                .config("spark.default.parallelism", "1")
-                .config("spark.sql.shuffle.partitions", "2")
-                .config("spark.sql.session.timeZone", "UTC")
-                .getOrCreate()
-            ),
-            "inferSchema": True,
-            "header": True,
-        }
+        if is_spark_connect := os.environ.get("SPARK_CONNECT", None):
+            from pyspark.sql.connect.session import SparkSession
+        else:
+            from pyspark.sql import SparkSession
+
+        builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
+        session = (
+            (
+                builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
+                if is_spark_connect
+                else builder.master("local[1]").config("spark.ui.enabled", "false")
+            )
+            .config("spark.default.parallelism", "1")
+            .config("spark.sql.shuffle.partitions", "2")
+            # common timezone for all tests environments
+            .config("spark.sql.session.timeZone", "UTC")
+            .getOrCreate()
+        )
+
+        kwargs = {"session": session, "inferSchema": True, "header": True}
     else:
         kwargs = {}
@@ -148,21 +156,28 @@ def test_scan_parquet(
         from sqlframe.duckdb import DuckDBSession

         kwargs = {"session": DuckDBSession(), "inferSchema": True}
-    elif "pyspark" in str(constructor):
-        from pyspark.sql import SparkSession
-
-        kwargs = {
-            "session": (
-                SparkSession.builder.appName("unit-tests")  # pyright: ignore[reportAttributeAccessIssue]
-                .master("local[1]")
-                .config("spark.ui.enabled", "false")
-                .config("spark.default.parallelism", "1")
-                .config("spark.sql.shuffle.partitions", "2")
-                .config("spark.sql.session.timeZone", "UTC")
-                .getOrCreate()
-            ),
-            "inferSchema": True,
-        }
+    elif "pyspark" in str(constructor):
+        if is_spark_connect := os.environ.get("SPARK_CONNECT", None):
+            from pyspark.sql.connect.session import SparkSession
+        else:
+            from pyspark.sql import SparkSession
+
+        builder = cast("SparkSession.Builder", SparkSession.builder).appName("unit-tests")
+        session = (
+            (
+                builder.remote(f"sc://localhost:{os.environ.get('SPARK_PORT', '15002')}")
+                if is_spark_connect
+                else builder.master("local[1]").config("spark.ui.enabled", "false")
+            )
+            .config("spark.default.parallelism", "1")
+            .config("spark.sql.shuffle.partitions", "2")
+            # common timezone for all tests environments
+            .config("spark.sql.session.timeZone", "UTC")
+            .getOrCreate()
+        )
+
+        kwargs = {"session": session, "inferSchema": True, "header": True}
     else:
         kwargs = {}
     df_pl = pl.DataFrame(data)