diff --git a/.github/workflows/typing.yml b/.github/workflows/typing.yml index 8df0cdf264..3264a53891 100644 --- a/.github/workflows/typing.yml +++ b/.github/workflows/typing.yml @@ -32,7 +32,7 @@ jobs: # TODO: add more dependencies/backends incrementally run: | source .venv/bin/activate - uv pip install -e ".[tests, typing, core]" + uv pip install -e ".[tests, typing, core, pyspark, sqlframe]" - name: show-deps run: | source .venv/bin/activate diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 6198835d37..be9c100606 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -6,6 +6,7 @@ from typing import Any from typing import Literal from typing import Sequence +from typing import cast from narwhals._spark_like.utils import evaluate_exprs from narwhals._spark_like.utils import native_to_narwhals_dtype @@ -26,7 +27,11 @@ import pyarrow as pa from pyspark.sql import Column from pyspark.sql import DataFrame + from pyspark.sql import Window + from pyspark.sql.session import SparkSession + from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame from typing_extensions import Self + from typing_extensions import TypeAlias from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.group_by import SparkLikeLazyGroupBy @@ -34,11 +39,17 @@ from narwhals.dtypes import DType from narwhals.utils import Version + SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any] + _NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame" + +Incomplete: TypeAlias = Any # pragma: no cover +"""Marker for working code that fails type checking.""" + class SparkLikeLazyFrame(CompliantLazyFrame): def __init__( self: Self, - native_dataframe: DataFrame, + native_dataframe: _NativeDataFrame, *, backend_version: tuple[int, ...], version: Version, @@ -54,7 +65,11 @@ def __init__( validate_backend_version(self._implementation, self._backend_version) @property - def _F(self: Self) -> Any: # noqa: N802 + def _F(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202, N802 + if TYPE_CHECKING: + from pyspark.sql import functions + + return functions if self._implementation is Implementation.SQLFRAME: from sqlframe.base.session import _BaseSession @@ -67,7 +82,12 @@ def _F(self: Self) -> Any: # noqa: N802 return functions @property - def _native_dtypes(self: Self) -> Any: + def _native_dtypes(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202 + if TYPE_CHECKING: + from pyspark.sql import types + + return types + if self._implementation is Implementation.SQLFRAME: from sqlframe.base.session import _BaseSession @@ -80,7 +100,7 @@ def _native_dtypes(self: Self) -> Any: return types @property - def _Window(self: Self) -> Any: # noqa: N802 + def _Window(self: Self) -> type[Window]: # noqa: N802 if self._implementation is Implementation.SQLFRAME: from sqlframe.base.session import _BaseSession @@ -94,11 +114,11 @@ def _Window(self: Self) -> Any: # noqa: N802 return Window @property - def _session(self: Self) -> Any: + def _session(self: Self) -> SparkSession: if self._implementation is Implementation.SQLFRAME: - return self._native_frame.session + return cast("SQLFrameDataFrame", self._native_frame).session - return self._native_frame.sparkSession + return cast("DataFrame", self._native_frame).sparkSession def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover return self._implementation.to_native_namespace() @@ -137,10 +157,9 @@ def _collect_to_arrow(self) -> pa.Table: ): import pyarrow as pa # ignore-banned-import + native_frame = cast("DataFrame", self._native_frame) try: - native_pyarrow_frame = pa.Table.from_batches( - self._native_frame._collect_as_arrow() - ) + return pa.Table.from_batches(native_frame._collect_as_arrow()) except ValueError as exc: if "at least one RecordBatch" in str(exc): # Empty dataframe @@ -154,7 +173,7 @@ def _collect_to_arrow(self) -> pa.Table: try: native_dtype = narwhals_to_native_dtype(value, self._version) except Exception as exc: # noqa: BLE001 - native_spark_dtype = self._native_frame.schema[key].dataType + native_spark_dtype = native_frame.schema[key].dataType # If we can't convert the type, just set it to `pa.null`, and warn. # Avoid the warning if we're starting from PySpark's void type. # We can avoid the check when we introduce `nw.Null` dtype. @@ -168,14 +187,13 @@ def _collect_to_arrow(self) -> pa.Table: schema.append((key, pa.null())) else: schema.append((key, native_dtype)) - native_pyarrow_frame = pa.Table.from_pydict( - data, schema=pa.schema(schema) - ) + return pa.Table.from_pydict(data, schema=pa.schema(schema)) else: # pragma: no cover raise else: - native_pyarrow_frame = self._native_frame.toArrow() - return native_pyarrow_frame + # NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309 + to_arrow: Incomplete = self._native_frame.toArrow + return to_arrow() @property def columns(self: Self) -> list[str]: @@ -246,10 +264,8 @@ def select( if not new_columns: # return empty dataframe, like Polars does - spark_df = self._session.createDataFrame( - [], self._native_dtypes.StructType([]) - ) - + schema = self._native_dtypes.StructType([]) + spark_df = self._session.createDataFrame([], schema) return self._from_native_frame(spark_df) new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns] @@ -272,7 +288,8 @@ def schema(self: Self) -> dict[str, DType]: field.name: native_to_narwhals_dtype( dtype=field.dataType, version=self._version, - spark_types=self._native_dtypes, + # NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662) + spark_types=self._native_dtypes, # pyright: ignore[reportArgumentType] ) for field in self._native_frame.schema } diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 0bc5725359..41bbe46c5e 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -31,7 +31,7 @@ from narwhals.utils import Version -class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]): +class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]): # type: ignore[type-var] # (#2044) _depth = 0 # Unused, just for compatibility with CompliantExpr def __init__( @@ -301,7 +301,7 @@ def __or__(self: Self, other: SparkLikeExpr) -> Self: ) def __invert__(self: Self) -> Self: - invert = cast("Callable[..., SparkLikeExpr]", operator.invert) + invert = cast("Callable[..., Column]", operator.invert) return self._from_call(invert, "__invert__") def abs(self: Self) -> Self: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 74d0841431..07303aee67 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -8,6 +8,7 @@ from typing import Iterable from typing import Literal from typing import Sequence +from typing import cast from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names @@ -29,7 +30,7 @@ from narwhals.utils import Version -class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): +class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): # type: ignore[type-var] # (#2044) def __init__( self: Self, *, @@ -222,7 +223,7 @@ def concat( *, how: Literal["horizontal", "vertical", "diagonal"], ) -> SparkLikeLazyFrame: - dfs: list[DataFrame] = [item._native_frame for item in items] + dfs = cast("Sequence[DataFrame]", [item._native_frame for item in items]) if how == "horizontal": msg = ( "Horizontal concatenation is not supported for LazyFrame backed by " diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index fd68125bb3..29558c3b32 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -11,55 +11,59 @@ from types import ModuleType import pyspark.sql.types as pyspark_types + import sqlframe.base.types as sqlframe_types from pyspark.sql import Column + from typing_extensions import TypeAlias from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr from narwhals.dtypes import DType from narwhals.utils import Version + _NativeDType: TypeAlias = "pyspark_types.DataType | sqlframe_types.DataType" + # NOTE: don't lru_cache this as `ModuleType` isn't hashable def native_to_narwhals_dtype( - dtype: pyspark_types.DataType, - version: Version, - spark_types: ModuleType, + dtype: _NativeDType, version: Version, spark_types: ModuleType ) -> DType: # pragma: no cover dtypes = import_dtypes_module(version=version) + if TYPE_CHECKING: + native = pyspark_types + else: + native = spark_types - if isinstance(dtype, spark_types.DoubleType): + if isinstance(dtype, native.DoubleType): return dtypes.Float64() - if isinstance(dtype, spark_types.FloatType): + if isinstance(dtype, native.FloatType): return dtypes.Float32() - if isinstance(dtype, spark_types.LongType): + if isinstance(dtype, native.LongType): return dtypes.Int64() - if isinstance(dtype, spark_types.IntegerType): + if isinstance(dtype, native.IntegerType): return dtypes.Int32() - if isinstance(dtype, spark_types.ShortType): + if isinstance(dtype, native.ShortType): return dtypes.Int16() - if isinstance(dtype, spark_types.ByteType): + if isinstance(dtype, native.ByteType): return dtypes.Int8() - if isinstance( - dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType) - ): + if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)): return dtypes.String() - if isinstance(dtype, spark_types.BooleanType): + if isinstance(dtype, native.BooleanType): return dtypes.Boolean() - if isinstance(dtype, spark_types.DateType): + if isinstance(dtype, native.DateType): return dtypes.Date() - if isinstance(dtype, spark_types.TimestampNTZType): + if isinstance(dtype, native.TimestampNTZType): return dtypes.Datetime() - if isinstance(dtype, spark_types.TimestampType): + if isinstance(dtype, native.TimestampType): return dtypes.Datetime(time_zone="UTC") - if isinstance(dtype, spark_types.DecimalType): + if isinstance(dtype, native.DecimalType): return dtypes.Decimal() - if isinstance(dtype, spark_types.ArrayType): + if isinstance(dtype, native.ArrayType): return dtypes.List( inner=native_to_narwhals_dtype( dtype.elementType, version=version, spark_types=spark_types ) ) - if isinstance(dtype, spark_types.StructType): + if isinstance(dtype, native.StructType): return dtypes.Struct( fields=[ dtypes.Field( @@ -78,48 +82,50 @@ def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version, spark_types: ModuleType ) -> pyspark_types.DataType: dtypes = import_dtypes_module(version) + if TYPE_CHECKING: + native = pyspark_types + else: + native = spark_types if isinstance_or_issubclass(dtype, dtypes.Float64): - return spark_types.DoubleType() + return native.DoubleType() if isinstance_or_issubclass(dtype, dtypes.Float32): - return spark_types.FloatType() + return native.FloatType() if isinstance_or_issubclass(dtype, dtypes.Int64): - return spark_types.LongType() + return native.LongType() if isinstance_or_issubclass(dtype, dtypes.Int32): - return spark_types.IntegerType() + return native.IntegerType() if isinstance_or_issubclass(dtype, dtypes.Int16): - return spark_types.ShortType() + return native.ShortType() if isinstance_or_issubclass(dtype, dtypes.Int8): - return spark_types.ByteType() + return native.ByteType() if isinstance_or_issubclass(dtype, dtypes.String): - return spark_types.StringType() + return native.StringType() if isinstance_or_issubclass(dtype, dtypes.Boolean): - return spark_types.BooleanType() + return native.BooleanType() if isinstance_or_issubclass(dtype, dtypes.Date): - return spark_types.DateType() + return native.DateType() if isinstance_or_issubclass(dtype, dtypes.Datetime): dt_time_zone = dtype.time_zone if dt_time_zone is None: - return spark_types.TimestampNTZType() + return native.TimestampNTZType() if dt_time_zone != "UTC": # pragma: no cover msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}" raise ValueError(msg) - return spark_types.TimestampType() + return native.TimestampType() if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)): - return spark_types.ArrayType( + return native.ArrayType( elementType=narwhals_to_native_dtype( - dtype.inner, version=version, spark_types=spark_types + dtype.inner, version=version, spark_types=native ) ) if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover - return spark_types.StructType( + return native.StructType( fields=[ - spark_types.StructField( + native.StructField( name=field.name, dataType=narwhals_to_native_dtype( - field.dtype, - version=version, - spark_types=spark_types, + field.dtype, version=version, spark_types=native ), ) for field in dtype.fields @@ -147,7 +153,7 @@ def narwhals_to_native_dtype( def evaluate_exprs( df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr ) -> list[tuple[str, Column]]: - native_results: list[tuple[str, list[Column]]] = [] + native_results: list[tuple[str, Column]] = [] for expr in exprs: native_series_list = expr._call(df) diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 6d77259775..071df0a155 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -18,11 +18,11 @@ import polars as pl import pyarrow as pa import pyspark.sql as pyspark_sql - import sqlframe from typing_extensions import TypeGuard from typing_extensions import TypeIs from narwhals._arrow.typing import ArrowChunkedArray + from narwhals._spark_like.dataframe import SQLFrameDataFrame from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series @@ -231,7 +231,7 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]: ) -def is_sqlframe_dataframe(df: Any) -> TypeIs[sqlframe.base.dataframe.BaseDataFrame]: +def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]: """Check whether `df` is a SQLFrame DataFrame without importing SQLFrame.""" return bool( (sqlframe := get_sqlframe()) is not None diff --git a/pyproject.toml b/pyproject.toml index 8a5dc1905a..7c3576d45e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,7 +254,6 @@ module = [ "ibis.*", "modin.*", "numpy.*", - "pyspark.*", "sklearn.*", "sqlframe.*", ] @@ -270,6 +269,7 @@ module = [ "*._ibis.*", "*._arrow.*", "*._dask.*", + "*._spark_like.*", ] warn_return_any = false diff --git a/tests/conftest.py b/tests/conftest.py index a0cde090c3..bf5d97b610 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,7 +154,7 @@ def pyspark_lazy_constructor() -> Callable[[Any], IntoFrame]: # pragma: no cove ) session = ( - SparkSession.builder.appName("unit-tests") + SparkSession.builder.appName("unit-tests") # pyright: ignore[reportAttributeAccessIssue] .master("local[1]") .config("spark.ui.enabled", "false") # executing one task at a time makes the tests faster @@ -172,7 +172,7 @@ def _constructor(obj: dict[str, list[Any]]) -> IntoFrame: index_col_name = generate_temporary_column_name(n_bytes=8, columns=list(_obj)) _obj[index_col_name] = list(range(len(_obj[next(iter(_obj))]))) - return ( # type: ignore[no-any-return] + return ( session.createDataFrame([*zip(*_obj.values())], schema=[*_obj.keys()]) .repartition(2) .orderBy(index_col_name)