6 changes: 3 additions & 3 deletions .github/workflows/pytest.yml
@@ -53,12 +53,12 @@ jobs:
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
# we are not testing pyspark on Windows here because it is very slow
run: uv pip install -e ".[tests, core, extra, dask, modin]" --system
run: uv pip install -e ".[tests, core, extra, dask, modin, sqlframe]" --system
- name: show-deps
run: uv pip freeze
- name: Run pytest
run: |
pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb
pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb,sqlframe

pytest-full-coverage:
strategy:
@@ -83,7 +83,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "pyproject.toml"
- name: install-reqs
run: uv pip install -e ".[tests, core, extra, modin, dask]" --system
run: uv pip install -e ".[tests, core, extra, modin, dask, sqlframe]" --system
- name: install pyspark
run: uv pip install -e ".[pyspark]" --system
# PySpark is not yet available on Python3.12+
146 changes: 114 additions & 32 deletions narwhals/_spark_like/dataframe.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import warnings
from importlib import import_module
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
@@ -13,6 +14,7 @@
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import check_column_names_are_unique
from narwhals.utils import find_stacklevel
from narwhals.utils import import_dtypes_module
from narwhals.utils import parse_columns_to_drop
@@ -23,6 +25,7 @@
from types import ModuleType

import pyarrow as pa
from pyspark.sql import Column
from pyspark.sql import DataFrame
from typing_extensions import Self

@@ -41,7 +44,10 @@ def __init__(
backend_version: tuple[int, ...],
version: Version,
implementation: Implementation,
validate_column_names: bool,
) -> None:
if validate_column_names:
check_column_names_are_unique(native_dataframe.columns)
Comment on lines +47 to +50
Member Author

While pyspark does it out of the box, sqlframe with the duckdb backend doesn't.
I think it is good to have it regardless, in order to standardize errors and error messages.
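
A minimal sketch of what such a uniqueness check could look like (the real `narwhals.utils.check_column_names_are_unique` may use a different error type and message):

```python
from collections import Counter
from typing import Sequence


def check_column_names_are_unique(columns: Sequence[str]) -> None:
    # Report every name that appears more than once, so the error reads the
    # same regardless of which native backend produced the frame.
    duplicates = {name: n for name, n in Counter(columns).items() if n > 1}
    if duplicates:
        details = ", ".join(f"{name!r} ({n}x)" for name, n in duplicates.items())
        msg = f"Expected unique column names, got duplicates: {details}"
        raise ValueError(msg)


# check_column_names_are_unique(["a", "b", "a"])  # would raise for 'a'
```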

self._native_frame = native_dataframe
self._backend_version = backend_version
self._implementation = implementation
@@ -51,33 +57,50 @@ def __init__(
@property
def _F(self: Self) -> Any: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import functions
from sqlframe.base.session import _BaseSession

return import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.functions"
)

return functions
from pyspark.sql import functions

return functions

@property
def _native_dtypes(self: Self) -> Any:
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import types
from sqlframe.base.session import _BaseSession

return import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.types"
)

return types
from pyspark.sql import types

return types

@property
def _Window(self: Self) -> Any: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.duckdb import Window
from sqlframe.base.session import _BaseSession

_window = import_module(
f"sqlframe.{_BaseSession().execution_dialect_name}.window"
)
return _window.Window

return Window
from pyspark.sql import Window

return Window

@property
def _session(self: Self) -> Any:
if self._implementation is Implementation.SQLFRAME:
return self._native_frame.session

return self._native_frame.sparkSession
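
Shown in isolation, the dispatch pattern these properties rely on: a minimal sketch, assuming (as the diff does) that `_BaseSession().execution_dialect_name` names the active sqlframe dialect and that each dialect package ships matching `functions`, `types` and `window` modules:

```python
from importlib import import_module
from typing import Any


def sqlframe_module(suffix: str) -> Any:
    """Resolve e.g. ``sqlframe.duckdb.functions`` from the active session.

    Hypothetical helper for illustration; it assumes the dialect name reported
    by ``_BaseSession`` matches a ``sqlframe.<dialect>`` subpackage.
    """
    from sqlframe.base.session import _BaseSession  # local import, as in the diff

    dialect = _BaseSession().execution_dialect_name
    return import_module(f"sqlframe.{dialect}.{suffix}")


# functions = sqlframe_module("functions")   # e.g. sqlframe.duckdb.functions
# Window = sqlframe_module("window").Window  # dialect-specific Window class
```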

def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover
return self._implementation.to_native_namespace()

@@ -99,14 +122,18 @@ def _change_version(self: Self, version: Version) -> Self:
backend_version=self._backend_version,
version=version,
implementation=self._implementation,
validate_column_names=False,
)

def _from_native_frame(self: Self, df: DataFrame) -> Self:
def _from_native_frame(
self: Self, df: DataFrame, *, validate_column_names: bool = True
) -> Self:
return self.__class__(
df,
backend_version=self._backend_version,
version=self._version,
implementation=self._implementation,
validate_column_names=validate_column_names,
)

def _collect_to_arrow(self) -> pa.Table:
@@ -205,7 +232,9 @@ def collect(
raise ValueError(msg) # pragma: no cover

def simple_select(self: Self, *column_names: str) -> Self:
return self._from_native_frame(self._native_frame.select(*column_names))
return self._from_native_frame(
self._native_frame.select(*column_names), validate_column_names=False
)

def aggregate(
self: Self,
@@ -214,7 +243,9 @@ def aggregate(
new_columns = parse_exprs(self, *exprs)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
return self._from_native_frame(
self._native_frame.agg(*new_columns_list), validate_column_names=False
)

def select(
self: Self,
@@ -224,17 +255,18 @@ def select(

if not new_columns:
# return empty dataframe, like Polars does
spark_session = self._native_frame.sparkSession
spark_df = spark_session.createDataFrame(
spark_df = self._session.createDataFrame(
[], self._native_dtypes.StructType([])
)

return self._from_native_frame(spark_df)
return self._from_native_frame(spark_df, validate_column_names=False)

new_columns_list = [
col.alias(col_name) for (col_name, col) in new_columns.items()
]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
return self._from_native_frame(
self._native_frame.select(*new_columns_list), validate_column_names=False
)

def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self:
new_columns = parse_exprs(self, *exprs)
@@ -244,7 +276,7 @@ def filter(self: Self, predicate: SparkLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
condition = predicate._call(self)[0]
spark_df = self._native_frame.where(condition)
return self._from_native_frame(spark_df)
return self._from_native_frame(spark_df, validate_column_names=False)

@property
def schema(self: Self) -> dict[str, DType]:
@@ -264,13 +296,13 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
return self._from_native_frame(self._native_frame.drop(*columns_to_drop))
return self._from_native_frame(
self._native_frame.drop(*columns_to_drop), validate_column_names=False
)

def head(self: Self, n: int) -> Self:
spark_session = self._native_frame.sparkSession

return self._from_native_frame(
spark_session.createDataFrame(self._native_frame.take(num=n))
self._native_frame.limit(num=n), validate_column_names=False
)

def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
@@ -301,10 +333,14 @@ def sort(
)

sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)]
return self._from_native_frame(self._native_frame.sort(*sort_cols))
return self._from_native_frame(
self._native_frame.sort(*sort_cols), validate_column_names=False
)

def drop_nulls(self: Self, subset: list[str] | None) -> Self:
return self._from_native_frame(self._native_frame.dropna(subset=subset))
return self._from_native_frame(
self._native_frame.dropna(subset=subset), validate_column_names=False
)

def rename(self: Self, mapping: dict[str, str]) -> Self:
rename_mapping = {
@@ -326,7 +362,9 @@ def unique(
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)
check_column_exists(self.columns, subset)
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))
return self._from_native_frame(
self._native_frame.dropDuplicates(subset=subset), validate_column_names=False
)

def join(
self: Self,
@@ -357,7 +395,7 @@ def join(
for colname in list(set(right_columns).difference(set(right_on or [])))
},
}
other = other_native.select(
other_native = other_native.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)

@@ -375,7 +413,7 @@
]
)
return self._from_native_frame(
self_native.join(other, on=left_on, how=how).select(col_order)
self_native.join(other_native, on=left_on, how=how).select(col_order)
)

def explode(self: Self, columns: list[str]) -> Self:
@@ -402,16 +440,51 @@ def explode(self: Self, columns: list[str]) -> Self:
)
raise NotImplementedError(msg)

return self._from_native_frame(
native_frame.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
if self._implementation.is_pyspark():
return self._from_native_frame(
native_frame.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
),
validate_column_names=False,
)
)
elif self._implementation.is_sqlframe():
Member Author

This is the same workaround as in the native duckdb backend. As order is not guaranteed, I think that's ok.

# Not every sqlframe dialect supports `explode_outer` function
# (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289)
# therefore we simply explode the array column, which ignores nulls and
# zero-sized arrays, and then append rows for those cases with null values
# (to match polars behavior).

def null_condition(col_name: str) -> Column:
return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)

return self._from_native_frame(
native_frame.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode(col_name).alias(col_name)
for col_name in column_names
]
).union(
native_frame.filter(null_condition(columns[0])).select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.lit(None).alias(col_name)
for col_name in column_names
]
)
),
validate_column_names=False,
)
else: # pragma: no cover
msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
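
For context, the Polars semantics the sqlframe branch above is emulating: `explode` keeps one row, with a null value, for both empty lists and null entries. A small illustrative snippet (expected output shown as a comment, assuming a recent Polars release):

```python
import polars as pl

df = pl.DataFrame({"idx": [0, 1, 2], "a": [[1, 2], [], None]})

# Both the empty list and the null entry still yield one row with a null value;
# this is what the explode-then-union-of-null-rows workaround reproduces.
print(df.explode("a").to_dicts())
# [{'idx': 0, 'a': 1}, {'idx': 0, 'a': 2}, {'idx': 1, 'a': None}, {'idx': 2, 'a': None}]
```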

def unpivot(
self: Self,
@@ -420,6 +493,15 @@ def unpivot(
variable_name: str,
value_name: str,
) -> Self:
if self._implementation.is_sqlframe():
if variable_name == "":
msg = "`variable_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)

if value_name == "":
msg = "`value_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)

ids = tuple(self.columns) if index is None else tuple(index)
values = (
tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)