Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions .github/workflows/pytest-pyspark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,75 @@ jobs:
run: uv pip freeze
- name: Run pytest
run: pytest tests --cov=narwhals/_spark_like --cov-fail-under=95 --runslow --constructors pyspark


pytest-pyspark-connect-constructor:
  # Opt-in job: only runs when the PR carries the `pyspark-connect` or
  # `release` label, since it downloads Spark and boots a Connect server.
  if: ${{ contains(github.event.pull_request.labels.*.name, 'pyspark-connect') || contains(github.event.pull_request.labels.*.name, 'release') }}
  strategy:
    matrix:
      python-version: ["3.10", "3.11"]
      os: [ubuntu-latest]
  env:
    SPARK_VERSION: 3.5.5
    SPARK_PORT: 15002
    SPARK_CONNECT: true
  runs-on: ${{ matrix.os }}
  steps:
    - uses: actions/checkout@v4
    - uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}

    - name: Install uv
      uses: astral-sh/setup-uv@v5
      with:
        enable-cache: "true"
        cache-suffix: ${{ matrix.python-version }}
        cache-dependency-glob: "pyproject.toml"

    # Spark requires a JVM; Spark 3.5.x supports Java 17.
    - name: Install Java 17
      uses: actions/setup-java@v4
      with:
        distribution: zulu
        java-version: 17

    - name: install-reqs
      run: uv pip install -e . --group core-tests --group extra --system
    - name: install pyspark
      # NOTE(review): piping `echo "setuptools<78"` into `uv pip install` has no
      # effect unless constraints are read from stdin (e.g. `-c -`) -- confirm intent.
      run: echo "setuptools<78" | uv pip install -e . "pyspark[connect]==${SPARK_VERSION}" --system
    - name: show-deps
      run: uv pip freeze

    # Cache the Spark distribution so repeat runs skip the ~400MB download.
    - name: Cache Spark
      id: cache-spark
      uses: actions/cache@v4
      with:
        path: /opt/spark
        key: spark-${{ env.SPARK_VERSION }}-bin-hadoop3

    - name: Download Spark
      if: steps.cache-spark.outputs.cache-hit != 'true'
      run: |
        wget https://archive.apache.org/dist/spark/spark-${SPARK_VERSION}/spark-${SPARK_VERSION}-bin-hadoop3.tgz
        tar -xzf spark-${SPARK_VERSION}-bin-hadoop3.tgz
        sudo mv spark-${SPARK_VERSION}-bin-hadoop3 /opt/spark

    - name: Set Spark env variables
      run: |
        echo "SPARK_HOME=/opt/spark" >> $GITHUB_ENV
        echo "/opt/spark/bin" >> $GITHUB_PATH

    - name: Start Spark Connect server
      run: |
        $SPARK_HOME/sbin/start-connect-server.sh \
          --packages org.apache.spark:spark-connect_2.12:${SPARK_VERSION} \
          --conf spark.connect.grpc.binding.port=${SPARK_PORT}
        sleep 5
        echo "Spark Connect server started"

    - name: Run pytest
      run: pytest tests --cov=narwhals/_spark_like --cov-fail-under=95 --runslow --constructors "pyspark[connect]"

    # `if: always()` so the server is torn down even when tests fail.
    - name: Stop Spark Connect server
      if: always()
      run: $SPARK_HOME/sbin/stop-connect-server.sh
21 changes: 18 additions & 3 deletions narwhals/_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import is_dask_dataframe
from narwhals.dependencies import is_duckdb_relation
from narwhals.dependencies import is_pyspark_connect_dataframe
from narwhals.dependencies import is_pyspark_dataframe
from narwhals.dependencies import is_sqlframe_dataframe
from narwhals.utils import Implementation
Expand All @@ -34,6 +35,7 @@
import polars as pl
import pyarrow as pa
import pyspark.sql as pyspark_sql
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
from typing_extensions import TypeAlias
from typing_extensions import TypeIs

Expand Down Expand Up @@ -72,7 +74,10 @@
_PandasLike, Implementation.PANDAS, Implementation.CUDF, Implementation.MODIN
]
SparkLike: TypeAlias = Literal[
_SparkLike, Implementation.PYSPARK, Implementation.SQLFRAME
_SparkLike,
Implementation.PYSPARK,
Implementation.SQLFRAME,
Implementation.PYSPARK_CONNECT,
]
EagerOnly: TypeAlias = "PandasLike | Arrow"
EagerAllowed: TypeAlias = "EagerOnly | Polars"
Expand Down Expand Up @@ -111,7 +116,10 @@ class _ModinSeries(Protocol):
_NativePandasLike: TypeAlias = "_NativePandas | _NativeCuDF | _NativeModin"
_NativeSQLFrame: TypeAlias = "SQLFrameDataFrame"
_NativePySpark: TypeAlias = "pyspark_sql.DataFrame"
_NativeSparkLike: TypeAlias = "_NativeSQLFrame | _NativePySpark"
_NativePySparkConnect: TypeAlias = "PySparkConnectDataFrame"
_NativeSparkLike: TypeAlias = (
"_NativeSQLFrame | _NativePySpark | _NativePySparkConnect"
)

NativeKnown: TypeAlias = "_NativePolars | _NativeArrow | _NativePandasLike | _NativeSparkLike | _NativeDuckDB | _NativeDask"
NativeUnknown: TypeAlias = (
Expand Down Expand Up @@ -292,6 +300,8 @@ def from_native_object(
return cls.from_backend(
Implementation.SQLFRAME
if is_native_sqlframe(native)
else Implementation.PYSPARK_CONNECT
if is_native_pyspark_connect(native)
else Implementation.PYSPARK
)
elif is_native_dask(native):
Expand Down Expand Up @@ -326,6 +336,7 @@ def is_native_dask(obj: Any) -> TypeIs[_NativeDask]:
# Module-level type guards: reuse the checks from `narwhals.dependencies`,
# re-typed as `_Guard[...]` so callers get `TypeIs` narrowing per native frame.
is_native_duckdb: _Guard[_NativeDuckDB] = is_duckdb_relation
is_native_sqlframe: _Guard[_NativeSQLFrame] = is_sqlframe_dataframe
is_native_pyspark: _Guard[_NativePySpark] = is_pyspark_dataframe
is_native_pyspark_connect: _Guard[_NativePySparkConnect] = is_pyspark_connect_dataframe


def is_native_pandas(obj: Any) -> TypeIs[_NativePandas]:
Expand All @@ -351,4 +362,8 @@ def is_native_pandas_like(obj: Any) -> TypeIs[_NativePandasLike]:


def is_native_spark_like(obj: Any) -> TypeIs[_NativeSparkLike]:
    """Return True if `obj` is any supported Spark-like native frame.

    Covers SQLFrame, classic PySpark, and PySpark Connect DataFrames.
    """
    # The diff view had left the stale pre-PR `return` line in place, making the
    # new expression unreachable; only the updated three-way check is kept.
    return (
        is_native_sqlframe(obj)
        or is_native_pyspark(obj)
        or is_native_pyspark_connect(obj)
    )
68 changes: 39 additions & 29 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,43 +141,53 @@ def _with_native(self, df: SQLFrameDataFrame) -> Self:
implementation=self._implementation,
)

def _to_arrow_schema(self) -> pa.Schema:  # pragma: no cover
    """Build a ``pyarrow.Schema`` matching this frame's Narwhals schema.

    Dtypes that cannot be converted to a PyArrow dtype are mapped to
    ``pa.null()``; a warning is emitted unless the source column is
    PySpark's void/``NullType`` (for which ``null`` is the right answer).
    """
    import pyarrow as pa  # ignore-banned-import

    from narwhals._arrow.utils import narwhals_to_native_dtype

    schema: list[tuple[str, pa.DataType]] = []
    nw_schema = self.collect_schema()
    native_schema = self.native.schema
    for key, value in nw_schema.items():
        try:
            native_dtype = narwhals_to_native_dtype(value, self._version)
        except Exception as exc:  # noqa: BLE001,PERF203
            native_spark_dtype = native_schema[key].dataType  # type: ignore[index]
            # If we can't convert the type, just set it to `pa.null`, and warn.
            # Avoid the warning if we're starting from PySpark's void type.
            # We can avoid the check when we introduce `nw.Null` dtype.
            null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
            if not isinstance(native_spark_dtype, null_type):
                warnings.warn(
                    f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
                    stacklevel=find_stacklevel(),
                )
            schema.append((key, pa.null()))
        else:
            schema.append((key, native_dtype))
    return pa.schema(schema)

def _collect_to_arrow(self) -> pa.Table:
    """Collect this lazy frame into a ``pyarrow.Table``.

    Dispatches on backend: pre-4.0 classic PySpark uses the private
    `_collect_as_arrow` (with an empty-frame fallback), pre-4.0 PySpark
    Connect goes via pandas, and everything else uses native `toArrow`.
    """
    if self._implementation.is_pyspark() and self._backend_version < (4,):
        import pyarrow as pa  # ignore-banned-import

        try:
            return pa.Table.from_batches(self.native._collect_as_arrow())
        except ValueError as exc:
            if "at least one RecordBatch" in str(exc):
                # Empty dataframe: `_collect_as_arrow` yields no batches, so
                # build an empty table from the converted schema instead.
                data: dict[str, list[Any]] = {k: [] for k in self.columns}
                pa_schema = self._to_arrow_schema()
                return pa.Table.from_pydict(data, schema=pa_schema)
            else:  # pragma: no cover
                raise
    elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
        import pyarrow as pa  # ignore-banned-import

        # Spark Connect < 4 has no Arrow collection API; round-trip through
        # pandas, pinning the schema so dtypes survive the conversion.
        pa_schema = self._to_arrow_schema()
        return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
    else:
        return self.native.toArrow()

Expand Down Expand Up @@ -293,7 +303,7 @@ def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
return self._with_native(self.native.drop(*columns_to_drop))

def head(self, n: int) -> Self:
    """Return a frame limited to the first `n` rows (lazy `limit`)."""
    # `limit` must be called positionally: per the PR review note, there is
    # no keyword argument named `num` on the Connect DataFrame.
    return self._with_native(self.native.limit(n))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No argument named num πŸ€·πŸΌβ€β™€οΈ


def group_by(
self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
Expand Down Expand Up @@ -445,7 +455,7 @@ def explode(self, columns: Sequence[str]) -> Self:
)
raise NotImplementedError(msg)

if self._implementation.is_pyspark():
if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
return self._with_native(
self.native.select(
*[
Expand Down
8 changes: 5 additions & 3 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Se

def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
return [
result.over(df._Window().partitionBy(df._F.lit(1))) for result in self(df)
result.over(self._Window().partitionBy(self._F.lit(1)))
for result in self(df)
]

return self.__class__(
Expand Down Expand Up @@ -438,7 +439,8 @@ def mean(self) -> Self:
def median(self) -> Self:
def _median(_input: Column) -> Column:
if (
self._implementation.is_pyspark()
self._implementation
in {Implementation.PYSPARK, Implementation.PYSPARK_CONNECT}
and (pyspark := get_pyspark()) is not None
and parse_version(pyspark) < (3, 4)
): # pragma: no cover
Expand Down Expand Up @@ -772,7 +774,7 @@ def _rank(_input: Column) -> Column:
else:
order_by_cols = [self._F.asc_nulls_last(_input)]

window = self._Window().orderBy(order_by_cols)
window = self._Window().partitionBy(self._F.lit(1)).orderBy(order_by_cols)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left over from #2429

count_window = self._Window().partitionBy(_input)

if method == "max":
Expand Down
13 changes: 13 additions & 0 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,10 @@ def import_functions(implementation: Implementation, /) -> ModuleType:
if implementation is Implementation.PYSPARK:
from pyspark.sql import functions

return functions
if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect import functions

return functions
from sqlframe.base.session import _BaseSession

Expand All @@ -254,6 +258,10 @@ def import_native_dtypes(implementation: Implementation, /) -> ModuleType:
if implementation is Implementation.PYSPARK:
from pyspark.sql import types

return types
if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect import types

return types
from sqlframe.base.session import _BaseSession

Expand All @@ -264,6 +272,11 @@ def import_window(implementation: Implementation, /) -> type[Any]:
if implementation is Implementation.PYSPARK:
from pyspark.sql import Window

return Window

if implementation is Implementation.PYSPARK_CONNECT:
from pyspark.sql.connect.window import Window

return Window
from sqlframe.base.session import _BaseSession

Expand Down
14 changes: 14 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import polars as pl
import pyarrow as pa
import pyspark.sql as pyspark_sql
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
from typing_extensions import TypeGuard
from typing_extensions import TypeIs

Expand Down Expand Up @@ -112,6 +113,11 @@ def get_pyspark_sql() -> Any:
return sys.modules.get("pyspark.sql", None)


def get_pyspark_connect() -> Any:
    """Return the `pyspark.sql.connect` module if already imported, else None."""
    return sys.modules.get("pyspark.sql.connect")


def get_sqlframe() -> Any:
    """Return the `sqlframe` module if already imported, else None."""
    return sys.modules.get("sqlframe")
Expand Down Expand Up @@ -230,6 +236,14 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
)


def is_pyspark_connect_dataframe(df: Any) -> TypeIs[PySparkConnectDataFrame]:
    """Check whether `df` is a PySpark Connect DataFrame without importing PySpark."""
    pyspark_connect = get_pyspark_connect()
    if pyspark_connect is None:
        return False
    return isinstance(df, pyspark_connect.dataframe.DataFrame)


def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]:
"""Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
if get_sqlframe() is not None:
Expand Down
Loading
Loading