diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 3042f6309d..6bb570e536 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -53,12 +53,12 @@ jobs: cache-dependency-glob: "pyproject.toml" - name: install-reqs # we are not testing pyspark on Windows here because it is very slow - run: uv pip install -e ".[tests, core, extra, dask, modin]" --system + run: uv pip install -e ".[tests, core, extra, dask, modin, sqlframe]" --system - name: show-deps run: uv pip freeze - name: Run pytest run: | - pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb + pytest tests --cov=narwhals --cov=tests --runslow --cov-fail-under=95 --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,modin[pyarrow],polars[eager],polars[lazy],dask,duckdb,sqlframe pytest-full-coverage: strategy: @@ -83,7 +83,7 @@ jobs: cache-suffix: ${{ matrix.python-version }} cache-dependency-glob: "pyproject.toml" - name: install-reqs - run: uv pip install -e ".[tests, core, extra, modin, dask]" --system + run: uv pip install -e ".[tests, core, extra, modin, dask, sqlframe]" --system - name: install pyspark run: uv pip install -e ".[pyspark]" --system # PySpark is not yet available on Python3.12+ diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 968ffaf1bc..6ee7a86b42 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +from importlib import import_module from typing import TYPE_CHECKING from typing import Any from typing import Literal @@ -13,6 +14,7 @@ from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists +from narwhals.utils import check_column_names_are_unique from narwhals.utils import find_stacklevel from narwhals.utils import import_dtypes_module from narwhals.utils import parse_columns_to_drop @@ -23,6 +25,7 @@ from types import ModuleType import pyarrow as pa + from pyspark.sql import Column from pyspark.sql import DataFrame from typing_extensions import Self @@ -41,7 +44,10 @@ def __init__( backend_version: tuple[int, ...], version: Version, implementation: Implementation, + validate_column_names: bool, ) -> None: + if validate_column_names: + check_column_names_are_unique(native_dataframe.columns) self._native_frame = native_dataframe self._backend_version = backend_version self._implementation = implementation @@ -51,9 +57,12 @@ def __init__( @property def _F(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import functions + from sqlframe.base.session import _BaseSession + + return import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.functions" + ) - return functions from pyspark.sql import functions return functions @@ -61,9 +70,12 @@ def _F(self: Self) -> Any: # noqa: N802 @property def _native_dtypes(self: Self) -> Any: if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import types + from sqlframe.base.session import _BaseSession + + return import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.types" + ) - return types from pyspark.sql import types return types @@ -71,13 +83,24 @@ def _native_dtypes(self: Self) -> Any: @property def _Window(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import Window + from sqlframe.base.session import _BaseSession + + _window = import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.window" + ) + return _window.Window - return Window from pyspark.sql import Window return Window + @property + def _session(self: Self) -> Any: + if self._implementation is Implementation.SQLFRAME: + return self._native_frame.session + + return self._native_frame.sparkSession + def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover return self._implementation.to_native_namespace() @@ -99,14 +122,18 @@ def _change_version(self: Self, version: Version) -> Self: backend_version=self._backend_version, version=version, implementation=self._implementation, + validate_column_names=False, ) - def _from_native_frame(self: Self, df: DataFrame) -> Self: + def _from_native_frame( + self: Self, df: DataFrame, *, validate_column_names: bool = True + ) -> Self: return self.__class__( df, backend_version=self._backend_version, version=self._version, implementation=self._implementation, + validate_column_names=validate_column_names, ) def _collect_to_arrow(self) -> pa.Table: @@ -205,7 +232,9 @@ def collect( raise ValueError(msg) # pragma: no cover def simple_select(self: Self, *column_names: str) -> Self: - return self._from_native_frame(self._native_frame.select(*column_names)) + return self._from_native_frame( + self._native_frame.select(*column_names), validate_column_names=False + ) def aggregate( self: Self, @@ -214,7 +243,9 @@ def aggregate( new_columns = parse_exprs(self, *exprs) new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()] - return self._from_native_frame(self._native_frame.agg(*new_columns_list)) + return self._from_native_frame( + self._native_frame.agg(*new_columns_list), validate_column_names=False + ) def select( self: Self, @@ -224,17 +255,18 @@ def select( if not new_columns: # return empty dataframe, like Polars does - spark_session = self._native_frame.sparkSession - spark_df = spark_session.createDataFrame( + spark_df = self._session.createDataFrame( [], self._native_dtypes.StructType([]) ) - return self._from_native_frame(spark_df) + return self._from_native_frame(spark_df, validate_column_names=False) new_columns_list = [ col.alias(col_name) for (col_name, col) in new_columns.items() ] - return self._from_native_frame(self._native_frame.select(*new_columns_list)) + return self._from_native_frame( + self._native_frame.select(*new_columns_list), validate_column_names=False + ) def with_columns(self: Self, *exprs: SparkLikeExpr) -> Self: new_columns = parse_exprs(self, *exprs) @@ -244,7 +276,7 @@ def filter(self: Self, predicate: SparkLikeExpr) -> Self: # `[0]` is safe as the predicate's expression only returns a single column condition = predicate._call(self)[0] spark_df = self._native_frame.where(condition) - return self._from_native_frame(spark_df) + return self._from_native_frame(spark_df, validate_column_names=False) @property def schema(self: Self) -> dict[str, DType]: @@ -264,13 +296,13 @@ def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 columns_to_drop = parse_columns_to_drop( compliant_frame=self, columns=columns, strict=strict ) - return self._from_native_frame(self._native_frame.drop(*columns_to_drop)) + return self._from_native_frame( + self._native_frame.drop(*columns_to_drop), validate_column_names=False + ) def head(self: Self, n: int) -> Self: - spark_session = self._native_frame.sparkSession - return self._from_native_frame( - spark_session.createDataFrame(self._native_frame.take(num=n)) + self._native_frame.limit(num=n), validate_column_names=False ) def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy: @@ -301,10 +333,14 @@ def sort( ) sort_cols = [sort_f(col) for col, sort_f in zip(by, sort_funcs)] - return self._from_native_frame(self._native_frame.sort(*sort_cols)) + return self._from_native_frame( + self._native_frame.sort(*sort_cols), validate_column_names=False + ) def drop_nulls(self: Self, subset: list[str] | None) -> Self: - return self._from_native_frame(self._native_frame.dropna(subset=subset)) + return self._from_native_frame( + self._native_frame.dropna(subset=subset), validate_column_names=False + ) def rename(self: Self, mapping: dict[str, str]) -> Self: rename_mapping = { @@ -326,7 +362,9 @@ def unique( msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`." raise ValueError(msg) check_column_exists(self.columns, subset) - return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset)) + return self._from_native_frame( + self._native_frame.dropDuplicates(subset=subset), validate_column_names=False + ) def join( self: Self, @@ -357,7 +395,7 @@ def join( for colname in list(set(right_columns).difference(set(right_on or []))) }, } - other = other_native.select( + other_native = other_native.select( [self._F.col(old).alias(new) for old, new in rename_mapping.items()] ) @@ -375,7 +413,7 @@ def join( ] ) return self._from_native_frame( - self_native.join(other, on=left_on, how=how).select(col_order) + self_native.join(other_native, on=left_on, how=how).select(col_order) ) def explode(self: Self, columns: list[str]) -> Self: @@ -402,16 +440,51 @@ def explode(self: Self, columns: list[str]) -> Self: ) raise NotImplementedError(msg) - return self._from_native_frame( - native_frame.select( - *[ - self._F.col(col_name).alias(col_name) - if col_name != columns[0] - else self._F.explode_outer(col_name).alias(col_name) - for col_name in column_names - ] + if self._implementation.is_pyspark(): + return self._from_native_frame( + native_frame.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.explode_outer(col_name).alias(col_name) + for col_name in column_names + ] + ), + validate_column_names=False, ) - ) + elif self._implementation.is_sqlframe(): + # Not every sqlframe dialect supports `explode_outer` function + # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289) + # therefore we simply explode the array column which will ignore nulls and + # zero sized arrays, and append these specific condition with nulls (to + # match polars behavior). + + def null_condition(col_name: str) -> Column: + return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0) + + return self._from_native_frame( + native_frame.select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.explode(col_name).alias(col_name) + for col_name in column_names + ] + ).union( + native_frame.filter(null_condition(columns[0])).select( + *[ + self._F.col(col_name).alias(col_name) + if col_name != columns[0] + else self._F.lit(None).alias(col_name) + for col_name in column_names + ] + ) + ), + validate_column_names=False, + ) + else: # pragma: no cover + msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) def unpivot( self: Self, @@ -420,6 +493,15 @@ def unpivot( variable_name: str, value_name: str, ) -> Self: + if self._implementation.is_sqlframe(): + if variable_name == "": + msg = "`variable_name` cannot be empty string for sqlframe backend." + raise NotImplementedError(msg) + + if value_name == "": + msg = "`value_name` cannot be empty string for sqlframe backend." + raise NotImplementedError(msg) + ids = tuple(self.columns) if index is None else tuple(index) values = ( tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index dbc833d3a5..0bc5725359 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +from importlib import import_module from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -15,6 +16,7 @@ from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace from narwhals._spark_like.utils import maybe_evaluate_expr from narwhals._spark_like.utils import narwhals_to_native_dtype +from narwhals.dependencies import get_pyspark from narwhals.typing import CompliantExpr from narwhals.utils import Implementation from narwhals.utils import parse_version @@ -75,31 +77,41 @@ def func(df: SparkLikeLazyFrame) -> Sequence[Column]: ) @property - def _F(self) -> Any: # noqa: N802 + def _F(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import functions + from sqlframe.base.session import _BaseSession + + return import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.functions" + ) - return functions from pyspark.sql import functions return functions @property - def _native_types(self) -> Any: + def _native_dtypes(self: Self) -> Any: if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import types + from sqlframe.base.session import _BaseSession + + return import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.types" + ) - return types from pyspark.sql import types return types @property - def _Window(self) -> Any: # noqa: N802 + def _Window(self: Self) -> Any: # noqa: N802 if self._implementation is Implementation.SQLFRAME: - from sqlframe.duckdb import Window + from sqlframe.base.session import _BaseSession + + _window = import_module( + f"sqlframe.{_BaseSession().execution_dialect_name}.window" + ) + return _window.Window - return Window from pyspark.sql import Window return Window @@ -321,7 +333,7 @@ def any(self: Self) -> Self: def cast(self: Self, dtype: DType | type[DType]) -> Self: def _cast(_input: Column) -> Column: spark_dtype = narwhals_to_native_dtype( - dtype, self._version, self._native_types + dtype, self._version, self._native_dtypes ) return _input.cast(spark_dtype) @@ -338,9 +350,11 @@ def mean(self: Self) -> Self: def median(self: Self) -> Self: def _median(_input: Column) -> Column: - import pyspark # ignore-banned-import - - if parse_version(pyspark) < (3, 4): + if ( + self._implementation.is_pyspark() + and (pyspark := get_pyspark()) is not None + and parse_version(pyspark) < (3, 4) + ): # Use percentile_approx with default accuracy parameter (10000) return self._F.percentile_approx(_input.cast("double"), 0.5) @@ -422,7 +436,7 @@ def _is_finite(_input: Column) -> Column: def is_in(self: Self, values: Sequence[Any]) -> Self: def _is_in(_input: Column) -> Column: - return _input.isin(values) + return _input.isin(values) if values else self._F.lit(False) # noqa: FBT003 return self._from_call(_is_in, "is_in") @@ -452,7 +466,7 @@ def skew(self: Self) -> Self: def n_unique(self: Self) -> Self: def _n_unique(_input: Column) -> Column: return self._F.count_distinct(_input) + self._F.max( - self._F.isnull(_input).cast(self._native_types.IntegerType()) + self._F.isnull(_input).cast(self._native_dtypes.IntegerType()) ) return self._from_call(_n_unique, "n_unique") diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 1b68d0734a..d64401e0fa 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -246,6 +246,7 @@ def concat( backend_version=self._backend_version, version=self._version, implementation=self._implementation, + validate_column_names=False, ) if how == "diagonal": @@ -256,6 +257,7 @@ def concat( backend_version=self._backend_version, version=self._version, implementation=self._implementation, + validate_column_names=False, ) raise NotImplementedError diff --git a/narwhals/translate.py b/narwhals/translate.py index 5dd7c83931..e545c2726c 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -738,6 +738,7 @@ def _from_native_impl( # noqa: PLR0915 backend_version=parse_version(get_pyspark()), version=version, implementation=Implementation.PYSPARK, + validate_column_names=True, ), level="lazy", ) @@ -760,6 +761,7 @@ def _from_native_impl( # noqa: PLR0915 backend_version=backend_version, version=version, implementation=Implementation.SQLFRAME, + validate_column_names=True, ), level="lazy", ) diff --git a/narwhals/utils.py b/narwhals/utils.py index 172fa5d96b..da8eadb80e 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -26,6 +26,7 @@ from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow from narwhals.dependencies import get_pyspark_sql +from narwhals.dependencies import get_sqlframe from narwhals.dependencies import is_cudf_series from narwhals.dependencies import is_modin_series from narwhals.dependencies import is_pandas_dataframe @@ -157,6 +158,7 @@ def from_native_namespace( get_dask_dataframe(): Implementation.DASK, get_duckdb(): Implementation.DUCKDB, get_ibis(): Implementation.IBIS, + get_sqlframe(): Implementation.SQLFRAME, } return mapping.get(native_namespace, Implementation.UNKNOWN) @@ -182,6 +184,7 @@ def from_string( "dask": Implementation.DASK, "duckdb": Implementation.DUCKDB, "ibis": Implementation.IBIS, + "sqlframe": Implementation.SQLFRAME, } return mapping.get(backend_name, Implementation.UNKNOWN) @@ -244,6 +247,12 @@ def to_native_namespace(self: Self) -> ModuleType: import duckdb # ignore-banned-import return duckdb + + if self is Implementation.SQLFRAME: + import sqlframe # ignore-banned-import + + return sqlframe + msg = "Not supported Implementation" # pragma: no cover raise AssertionError(msg) @@ -283,6 +292,22 @@ def is_pandas_like(self: Self) -> bool: Implementation.CUDF, } + def is_spark_like(self: Self) -> bool: + """Return whether implementation is pyspark or sqlframe. + + Returns: + Boolean. + + Examples: + >>> import pandas as pd + >>> import narwhals as nw + >>> df_native = pd.DataFrame({"a": [1, 2, 3]}) + >>> df = nw.from_native(df_native) + >>> df.implementation.is_spark_like() + False + """ + return self in {Implementation.PYSPARK, Implementation.SQLFRAME} + def is_polars(self: Self) -> bool: """Return whether implementation is Polars. @@ -411,6 +436,22 @@ def is_ibis(self: Self) -> bool: """ return self is Implementation.IBIS # pragma: no cover + def is_sqlframe(self: Self) -> bool: + """Return whether implementation is SQLFrame. + + Returns: + Boolean. + + Examples: + >>> import polars as pl + >>> import narwhals as nw + >>> df_native = pl.DataFrame({"a": [1, 2, 3]}) + >>> df = nw.from_native(df_native) + >>> df.implementation.is_sqlframe() + False + """ + return self is Implementation.SQLFRAME # pragma: no cover + MIN_VERSIONS: dict[Implementation, tuple[int, ...]] = { Implementation.PANDAS: (0, 25, 3), @@ -422,7 +463,7 @@ def is_ibis(self: Self) -> bool: Implementation.DASK: (2024, 8), Implementation.DUCKDB: (1,), Implementation.IBIS: (6,), - Implementation.SQLFRAME: (3, 14, 2), + Implementation.SQLFRAME: (3, 22, 0), } diff --git a/pyproject.toml b/pyproject.toml index 44984eb976..649a076a42 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ polars = ["polars>=0.20.3"] dask = ["dask[dataframe]>=2024.8"] duckdb = ["duckdb>=1.0"] ibis = ["ibis-framework>=6.0.0", "rich", "packaging", "pyarrow_hotfix"] +sqlframe = ["sqlframe>=3.22.0"] tests = [ "covdefaults", "pytest", @@ -201,7 +202,8 @@ filterwarnings = [ 'ignore:.*is_datetime64tz_dtype is deprecated and will be removed in a future version.*:DeprecationWarning:pyspark', # Warning raised by PyArrow nightly just by importing pandas 'ignore:.*Python binding for RankQuantileOptions not exposed:RuntimeWarning:pyarrow', - 'ignore:.*pandas only supports SQLAlchemy:UserWarning' + 'ignore:.*pandas only supports SQLAlchemy:UserWarning', + 'ignore:.*numpy.core is deprecated and has been renamed to numpy._core.*:DeprecationWarning:sqlframe', ] xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/tests/conftest.py b/tests/conftest.py index 9ee92bcbd6..a0cde090c3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -184,7 +184,7 @@ def _constructor(obj: dict[str, list[Any]]) -> IntoFrame: def sqlframe_pyspark_lazy_constructor( obj: dict[str, Any], -) -> Callable[[Any], IntoFrame]: # pragma: no cover +) -> IntoFrame: # pragma: no cover from sqlframe.duckdb import DuckDBSession session = DuckDBSession() @@ -208,9 +208,7 @@ def sqlframe_pyspark_lazy_constructor( "polars[lazy]": polars_lazy_constructor, "duckdb": duckdb_lazy_constructor, "pyspark": pyspark_lazy_constructor, # type: ignore[dict-item] - # We've reported several bugs to sqlframe - once they address - # them, we can start testing them as part of our CI. - # "sqlframe": sqlframe_pyspark_lazy_constructor, # noqa: ERA001 + "sqlframe": sqlframe_pyspark_lazy_constructor, } GPU_CONSTRUCTORS: dict[str, Callable[[Any], IntoFrame]] = {"cudf": cudf_constructor} diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index 1fa603bd81..1f512cde40 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -54,7 +54,7 @@ "p": nw.Int64, } -SPARK_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "l", "o", "p"} +SPARK_LIKE_INCOMPATIBLE_COLUMNS = {"e", "f", "g", "h", "l", "o", "p"} DUCKDB_INCOMPATIBLE_COLUMNS = {"l", "o", "p"} @@ -72,7 +72,7 @@ def test_cast( request.applymarker(pytest.mark.xfail) if "pyspark" in str(constructor): - incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS # pragma: no cover + incompatible_columns = SPARK_LIKE_INCOMPATIBLE_COLUMNS # pragma: no cover elif "duckdb" in str(constructor): incompatible_columns = DUCKDB_INCOMPATIBLE_COLUMNS # pragma: no cover else: @@ -185,7 +185,7 @@ def test_cast_raises_for_unknown_dtype( request.applymarker(pytest.mark.xfail) if "pyspark" in str(constructor): - incompatible_columns = SPARK_INCOMPATIBLE_COLUMNS # pragma: no cover + incompatible_columns = SPARK_LIKE_INCOMPATIBLE_COLUMNS # pragma: no cover else: incompatible_columns = set() @@ -236,7 +236,9 @@ def test_cast_datetime_tz_aware( def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(backend in str(constructor) for backend in ("dask", "modin", "cudf")): + if any( + backend in str(constructor) for backend in ("dask", "modin", "cudf", "sqlframe") + ): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor) and PANDAS_VERSION < (2, 2): @@ -254,8 +256,9 @@ def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) - if "spark" in str(constructor): # pragma: no cover # Special handling for pyspark as it natively maps the input to # a column of type MAP - import pyspark.sql.functions as F # noqa: N812 - import pyspark.sql.types as T # noqa: N812 + _tmp_nw_compliant_frame = nw.from_native(native_df)._compliant_frame + F = _tmp_nw_compliant_frame._F # noqa: N806 + T = _tmp_nw_compliant_frame._native_dtypes # noqa: N806 native_df = native_df.withColumn( # type: ignore[union-attr] "a", diff --git a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py index c70b5f8445..28aaa71373 100644 --- a/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py +++ b/tests/expr_and_series/str/to_uppercase_to_lowercase_test.py @@ -37,6 +37,7 @@ def test_str_to_uppercase( "pyarrow_table_constructor", "modin_pyarrow_constructor", "duckdb_lazy_constructor", + "sqlframe_pyspark_lazy_constructor", } or ("dask" in str(constructor) and PYARROW_VERSION >= (12,)) ): @@ -44,7 +45,6 @@ def test_str_to_uppercase( # since the pyarrow backend will convert # smaller cap 'ß' to upper cap 'ẞ' instead of 'SS' request.applymarker(pytest.mark.xfail) - df = nw.from_native(constructor(data)) result_frame = df.select(nw.col("a").str.to_uppercase()) diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index 280ace8258..40c9a49ef0 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -77,7 +77,11 @@ def test_unary_series(constructor_eager: ConstructorEager) -> None: assert_equal_data(result, expected) -def test_unary_two_elements(constructor: Constructor) -> None: +def test_unary_two_elements( + constructor: Constructor, request: pytest.FixtureRequest +) -> None: + if "sqlframe" in str(constructor): + request.applymarker(pytest.mark.xfail) data = {"a": [1, 2], "b": [2, 10], "c": [2.0, None]} result = nw.from_native(constructor(data)).select( a_nunique=nw.col("a").n_unique(), @@ -123,7 +127,7 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None: def test_unary_one_element( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "pyspark" in str(constructor): + if "pyspark" in str(constructor) and "sqlframe" not in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1], "b": [2], "c": [None]} # Dask runs into a divide by zero RuntimeWarning for 1 element skew. diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 7105863843..41a9823d6b 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -669,14 +669,14 @@ def test_joinasof_by_exceptions(constructor: Constructor) -> None: def test_join_duplicate_column_names( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "polars" in str(constructor): - # https://github.com/pola-rs/polars/issues/21048 - request.applymarker(pytest.mark.xfail) - if "cudf" in str(constructor): + if ( + "polars" in str(constructor) # https://github.com/pola-rs/polars/issues/21048 + or "cudf" in str(constructor) # TODO(unassigned): cudf doesn't raise here for some reason, # need to investigate. + ): request.applymarker(pytest.mark.xfail) - if "pyspark" in str(constructor): + if "pyspark" in str(constructor) and "sqlframe" not in str(constructor): from pyspark.errors import AnalysisException exception = AnalysisException diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py index ee46a4a67e..f329723262 100644 --- a/tests/frame/select_test.py +++ b/tests/frame/select_test.py @@ -30,7 +30,7 @@ def test_select(constructor: Constructor) -> None: def test_empty_select(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if "duckdb" in str(constructor) or "sqlframe" in str(constructor): request.applymarker(pytest.mark.xfail) result = nw.from_native(constructor({"a": [1, 2, 3]})).lazy().select() assert result.collect().shape == (0, 0) diff --git a/tests/frame/unique_test.py b/tests/frame/unique_test.py index 8e1b71de0e..61d64200f5 100644 --- a/tests/frame/unique_test.py +++ b/tests/frame/unique_test.py @@ -39,9 +39,7 @@ def test_unique( "last", }: context: Any = pytest.raises(ValueError, match="row order") - elif ( - keep == "none" and df.implementation is nw.Implementation.PYSPARK - ): # pragma: no cover + elif keep == "none" and df.implementation.is_spark_like(): # pragma: no cover context = pytest.raises( ValueError, match="`LazyFrame.unique` with PySpark backend only supports `keep='any'`.", diff --git a/tests/frame/unpivot_test.py b/tests/frame/unpivot_test.py index e1573c927d..706defa92e 100644 --- a/tests/frame/unpivot_test.py +++ b/tests/frame/unpivot_test.py @@ -90,7 +90,15 @@ def test_unpivot_var_value_names( ) -> None: context = ( pytest.raises(NotImplementedError) - if ("duckdb" in str(constructor) and any([variable_name == "", value_name == ""])) + if ( + any([variable_name == "", value_name == ""]) + and ( + "duckdb" in str(constructor) + # This might depend from the dialect we use in sqlframe. + # Since for now we use only duckdb, we need to xfail it + or "sqlframe" in str(constructor) + ) + ) else does_not_raise() ) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index c7703c5d46..b9e06f7e96 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -221,7 +221,7 @@ def test_set_ops( expected: list[str], request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) and not expected: + if ("duckdb" in str(constructor) or "sqlframe" in str(constructor)) and not expected: request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) result = df.select(selector).collect_schema().names()