Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9c420d0
remove pyspark from mypy overrides
EdAbati Feb 16, 2025
cc2b5de
some mypy fixes
EdAbati Feb 16, 2025
59a93f5
Merge remote-tracking branch 'upstream/main' into pyspark-typing
EdAbati Feb 16, 2025
0a33350
Merge remote-tracking branch 'upstream/main' into pyspark-typing
EdAbati Feb 20, 2025
f2c8e6a
ignore type var
EdAbati Feb 20, 2025
e429278
add more tests
EdAbati Feb 20, 2025
8588444
we already test on = none
EdAbati Feb 20, 2025
803cad9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2025
2e8c9d0
merge upstream
EdAbati Feb 24, 2025
68107d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2025
9558c69
fix pyspark unpivot
EdAbati Feb 24, 2025
a3740d4
Merge branch 'main' into pyspark-typing
EdAbati Feb 24, 2025
53c10ee
add pyspark typing ci
EdAbati Feb 24, 2025
f4f3ec7
fix type
EdAbati Feb 24, 2025
c71e3f9
simplify type unpivot
EdAbati Feb 24, 2025
5b4ce3a
make pyright happy
EdAbati Feb 24, 2025
37861aa
Merge branch 'main' into pyspark-typing
dangotbanned Feb 24, 2025
6ff24c2
Merge remote-tracking branch 'upstream/main' into pyspark-typing
EdAbati Feb 25, 2025
e6c08c8
fix invert type
EdAbati Feb 25, 2025
a353ff3
ignore return any
EdAbati Feb 25, 2025
5ec57e2
add pr number to ingore comment
EdAbati Feb 25, 2025
a0fd56c
Merge remote-tracking branch 'upstream/main' into pyspark-typing
EdAbati Feb 25, 2025
1167ba3
chore(typing): Mark `toArrow` as `Incomplete`
dangotbanned Feb 25, 2025
1b03d02
fix dtype typing
EdAbati Feb 25, 2025
2b456f7
Merge remote-tracking branch 'upstream/main' into pr/EdAbati/2051
dangotbanned Feb 25, 2025
c46f2d8
feat(typing): Help `pyright` with `_native_dtypes`
dangotbanned Feb 25, 2025
be91607
Merge remote-tracking branch 'upstream/main' into pr/EdAbati/2051
dangotbanned Feb 25, 2025
305644c
chore(typing): Annotate `_Window`
dangotbanned Feb 25, 2025
b21eab5
feat(typing): help `pyright` with `_F`
dangotbanned Feb 25, 2025
ad4606e
chore(typing): Annotate `_session`
dangotbanned Feb 25, 2025
788bfbd
Merge remote-tracking branch 'upstream/main' into pr/EdAbati/2051
dangotbanned Feb 25, 2025
9a4def1
chore(typing): Add ignore and note for hashing module
dangotbanned Feb 25, 2025
ada5591
ci(typing): install `sqlframe`
dangotbanned Feb 25, 2025
9448db0
fix(typing): resolve `[attr-defined]`
dangotbanned Feb 25, 2025
d04af00
Merge remote-tracking branch 'upstream/main' into pr/EdAbati/2051
dangotbanned Feb 25, 2025
11bde2e
fix(typing): temp resolve impl differences
dangotbanned Feb 25, 2025
93ccbfc
Merge branch 'main' into pyspark-typing
dangotbanned Feb 26, 2025
10ebdc1
Merge branch 'main' into pyspark-typing
dangotbanned Feb 26, 2025
271f726
Merge branch 'main' into pyspark-typing
dangotbanned Feb 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/typing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
# TODO: add more dependencies/backends incrementally
run: |
source .venv/bin/activate
uv pip install -e ".[tests, typing, core]"
uv pip install -e ".[tests, typing, core, pyspark, sqlframe]"
- name: show-deps
run: |
source .venv/bin/activate
Expand Down
59 changes: 38 additions & 21 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._spark_like.utils import evaluate_exprs
from narwhals._spark_like.utils import native_to_narwhals_dtype
Expand All @@ -26,19 +27,29 @@
import pyarrow as pa
from pyspark.sql import Column
from pyspark.sql import DataFrame
from pyspark.sql import Window
from pyspark.sql.session import SparkSession
from sqlframe.base.dataframe import BaseDataFrame as _SQLFrameDataFrame
from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version

SQLFrameDataFrame: TypeAlias = _SQLFrameDataFrame[Any, Any, Any, Any, Any]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn that's a hell of a generic πŸ˜‚

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Damn that's a hell of a generic πŸ˜‚

Ikr!
All the Unknown(s) in this gave me quite the scare
https://github.com/narwhals-dev/narwhals/actions/runs/13527240491/job/37800780977?pr=2051

_NativeDataFrame: TypeAlias = "DataFrame | SQLFrameDataFrame"

Incomplete: TypeAlias = Any # pragma: no cover
"""Marker for working code that fails type checking."""


class SparkLikeLazyFrame(CompliantLazyFrame):
def __init__(
self: Self,
native_dataframe: DataFrame,
native_dataframe: _NativeDataFrame,
*,
backend_version: tuple[int, ...],
version: Version,
Expand All @@ -54,7 +65,11 @@ def __init__(
validate_backend_version(self._implementation, self._backend_version)

@property
def _F(self: Self) -> Any: # noqa: N802
def _F(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202, N802
if TYPE_CHECKING:
from pyspark.sql import functions

return functions
if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

Expand All @@ -67,7 +82,12 @@ def _F(self: Self) -> Any: # noqa: N802
return functions

@property
def _native_dtypes(self: Self) -> Any:
def _native_dtypes(self: Self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from pyspark.sql import types

return types
Comment on lines +86 to +89
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any drawback to this? I have barely seen these conditional in non-import blocks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@dangotbanned dangotbanned Feb 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🀯🀯

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any drawback to this? I have barely seen these conditional in non-import blocks

I'm not sure I'd say drawback - but I would recommend only trying this as a last resort.
Here - I'm using it to describe something the type system doesn't support yet.

Another use was #2064 (comment) - for a backwards compat bug fix.

I'd much rather do things in a simple way whenever possible

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did a similar thing here to convert these functions into typeguards

https://github.com/zen-xu/pyarrow-stubs/blob/c482d8f146a8cf2cff55cc4060148fe4883f9309/pyarrow-stubs/types.pyi

if TYPE_CHECKING:
from typing import TypeVar
from typing_extensions import TypeAlias
from typing_extensions import TypeIs
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import ArrowArray
from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._arrow.typing import Incomplete
from narwhals._arrow.typing import StringArray
from narwhals.dtypes import DType
from narwhals.typing import _AnyDArray
from narwhals.utils import Version
# NOTE: stubs don't allow for `ChunkedArray[StructArray]`
# Intended to represent the `.chunks` property storing `list[pa.StructArray]`
ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray
_T = TypeVar("_T")
def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
def is_dictionary(
t: Any,
) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
def extract_regex(
strings: ArrowChunkedArray,
/,
pattern: str,
*,
options: Any = None,
memory_pool: Any = None,
) -> ChunkedArrayStructArray: ...
else:
from pyarrow.compute import extract_regex
from pyarrow.types import is_dictionary # noqa: F401
from pyarrow.types import is_duration
from pyarrow.types import is_fixed_size_list
from pyarrow.types import is_large_list
from pyarrow.types import is_list
from pyarrow.types import is_timestamp

Copy link
Member

@dangotbanned dangotbanned Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to mention, but (#2051 (comment)) would be the direction to go in - to avoid needing TYPE_CHECKING blocks.

Keeping everything in one class means you need to either do this - or introduce Union(s) and isinstance checks to resolve them on every usage - to type this correctly.

Currently, we branch on Implementation but we can't represent the relationship between Implementation and these unions:

  • pyspark.sql.types or sqlframe.base.types
  • pyspark.sql.functions or sqlframe.base.functions
  • pyspark.sql.Window or sqlframe.base.window.Window
  • pyspark.sql.dataframe.DataFrame.sparkSession or sqlframe.base.dataframe.BaseDataFrame.session

They're a mix of modules, classes and object attributes


if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

Expand All @@ -80,7 +100,7 @@ def _native_dtypes(self: Self) -> Any:
return types

@property
def _Window(self: Self) -> Any: # noqa: N802
def _Window(self: Self) -> type[Window]: # noqa: N802
if self._implementation is Implementation.SQLFRAME:
from sqlframe.base.session import _BaseSession

Expand All @@ -94,11 +114,11 @@ def _Window(self: Self) -> Any: # noqa: N802
return Window

@property
def _session(self: Self) -> Any:
def _session(self: Self) -> SparkSession:
if self._implementation is Implementation.SQLFRAME:
return self._native_frame.session
return cast("SQLFrameDataFrame", self._native_frame).session

return self._native_frame.sparkSession
return cast("DataFrame", self._native_frame).sparkSession

def __native_namespace__(self: Self) -> ModuleType: # pragma: no cover
return self._implementation.to_native_namespace()
Expand Down Expand Up @@ -137,10 +157,9 @@ def _collect_to_arrow(self) -> pa.Table:
):
import pyarrow as pa # ignore-banned-import

native_frame = cast("DataFrame", self._native_frame)
try:
native_pyarrow_frame = pa.Table.from_batches(
self._native_frame._collect_as_arrow()
)
return pa.Table.from_batches(native_frame._collect_as_arrow())
except ValueError as exc:
if "at least one RecordBatch" in str(exc):
# Empty dataframe
Expand All @@ -154,7 +173,7 @@ def _collect_to_arrow(self) -> pa.Table:
try:
native_dtype = narwhals_to_native_dtype(value, self._version)
except Exception as exc: # noqa: BLE001
native_spark_dtype = self._native_frame.schema[key].dataType
native_spark_dtype = native_frame.schema[key].dataType
# If we can't convert the type, just set it to `pa.null`, and warn.
# Avoid the warning if we're starting from PySpark's void type.
# We can avoid the check when we introduce `nw.Null` dtype.
Expand All @@ -168,14 +187,13 @@ def _collect_to_arrow(self) -> pa.Table:
schema.append((key, pa.null()))
else:
schema.append((key, native_dtype))
native_pyarrow_frame = pa.Table.from_pydict(
data, schema=pa.schema(schema)
)
return pa.Table.from_pydict(data, schema=pa.schema(schema))
else: # pragma: no cover
raise
else:
native_pyarrow_frame = self._native_frame.toArrow()
return native_pyarrow_frame
# NOTE: See https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1969224309
to_arrow: Incomplete = self._native_frame.toArrow
return to_arrow()

@property
def columns(self: Self) -> list[str]:
Expand Down Expand Up @@ -246,10 +264,8 @@ def select(

if not new_columns:
# return empty dataframe, like Polars does
spark_df = self._session.createDataFrame(
[], self._native_dtypes.StructType([])
)

schema = self._native_dtypes.StructType([])
spark_df = self._session.createDataFrame([], schema)
return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
Expand All @@ -272,7 +288,8 @@ def schema(self: Self) -> dict[str, DType]:
field.name: native_to_narwhals_dtype(
dtype=field.dataType,
version=self._version,
spark_types=self._native_dtypes,
# NOTE: Unclear if this is an unsafe hash (https://github.com/narwhals-dev/narwhals/pull/2051#discussion_r1970074662)
spark_types=self._native_dtypes, # pyright: ignore[reportArgumentType]
)
for field in self._native_frame.schema
}
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from narwhals.utils import Version


class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]):
class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]): # type: ignore[type-var] # (#2044)
_depth = 0 # Unused, just for compatibility with CompliantExpr

def __init__(
Expand Down Expand Up @@ -301,7 +301,7 @@ def __or__(self: Self, other: SparkLikeExpr) -> Self:
)

def __invert__(self: Self) -> Self:
invert = cast("Callable[..., SparkLikeExpr]", operator.invert)
invert = cast("Callable[..., Column]", operator.invert)
return self._from_call(invert, "__invert__")

def abs(self: Self) -> Self:
Expand Down
5 changes: 3 additions & 2 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
Expand All @@ -29,7 +30,7 @@
from narwhals.utils import Version


class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):
class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): # type: ignore[type-var] # (#2044)
def __init__(
self: Self,
*,
Expand Down Expand Up @@ -222,7 +223,7 @@ def concat(
*,
how: Literal["horizontal", "vertical", "diagonal"],
) -> SparkLikeLazyFrame:
dfs: list[DataFrame] = [item._native_frame for item in items]
dfs = cast("Sequence[DataFrame]", [item._native_frame for item in items])
if how == "horizontal":
msg = (
"Horizontal concatenation is not supported for LazyFrame backed by "
Expand Down
82 changes: 44 additions & 38 deletions narwhals/_spark_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,55 +11,59 @@
from types import ModuleType

import pyspark.sql.types as pyspark_types
import sqlframe.base.types as sqlframe_types
from pyspark.sql import Column
from typing_extensions import TypeAlias

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals.dtypes import DType
from narwhals.utils import Version

_NativeDType: TypeAlias = "pyspark_types.DataType | sqlframe_types.DataType"


# NOTE: don't lru_cache this as `ModuleType` isn't hashable
def native_to_narwhals_dtype(
dtype: pyspark_types.DataType,
version: Version,
spark_types: ModuleType,
dtype: _NativeDType, version: Version, spark_types: ModuleType
) -> DType: # pragma: no cover
dtypes = import_dtypes_module(version=version)
if TYPE_CHECKING:
native = pyspark_types
else:
native = spark_types

if isinstance(dtype, spark_types.DoubleType):
if isinstance(dtype, native.DoubleType):
return dtypes.Float64()
if isinstance(dtype, spark_types.FloatType):
if isinstance(dtype, native.FloatType):
return dtypes.Float32()
if isinstance(dtype, spark_types.LongType):
if isinstance(dtype, native.LongType):
return dtypes.Int64()
if isinstance(dtype, spark_types.IntegerType):
if isinstance(dtype, native.IntegerType):
return dtypes.Int32()
if isinstance(dtype, spark_types.ShortType):
if isinstance(dtype, native.ShortType):
return dtypes.Int16()
if isinstance(dtype, spark_types.ByteType):
if isinstance(dtype, native.ByteType):
return dtypes.Int8()
if isinstance(
dtype, (spark_types.StringType, spark_types.VarcharType, spark_types.CharType)
):
if isinstance(dtype, (native.StringType, native.VarcharType, native.CharType)):
return dtypes.String()
if isinstance(dtype, spark_types.BooleanType):
if isinstance(dtype, native.BooleanType):
return dtypes.Boolean()
if isinstance(dtype, spark_types.DateType):
if isinstance(dtype, native.DateType):
return dtypes.Date()
if isinstance(dtype, spark_types.TimestampNTZType):
if isinstance(dtype, native.TimestampNTZType):
return dtypes.Datetime()
if isinstance(dtype, spark_types.TimestampType):
if isinstance(dtype, native.TimestampType):
return dtypes.Datetime(time_zone="UTC")
if isinstance(dtype, spark_types.DecimalType):
if isinstance(dtype, native.DecimalType):
return dtypes.Decimal()
if isinstance(dtype, spark_types.ArrayType):
if isinstance(dtype, native.ArrayType):
return dtypes.List(
inner=native_to_narwhals_dtype(
dtype.elementType, version=version, spark_types=spark_types
)
)
if isinstance(dtype, spark_types.StructType):
if isinstance(dtype, native.StructType):
return dtypes.Struct(
fields=[
dtypes.Field(
Expand All @@ -78,48 +82,50 @@ def narwhals_to_native_dtype(
dtype: DType | type[DType], version: Version, spark_types: ModuleType
) -> pyspark_types.DataType:
dtypes = import_dtypes_module(version)
if TYPE_CHECKING:
native = pyspark_types
else:
native = spark_types

if isinstance_or_issubclass(dtype, dtypes.Float64):
return spark_types.DoubleType()
return native.DoubleType()
if isinstance_or_issubclass(dtype, dtypes.Float32):
return spark_types.FloatType()
return native.FloatType()
if isinstance_or_issubclass(dtype, dtypes.Int64):
return spark_types.LongType()
return native.LongType()
if isinstance_or_issubclass(dtype, dtypes.Int32):
return spark_types.IntegerType()
return native.IntegerType()
if isinstance_or_issubclass(dtype, dtypes.Int16):
return spark_types.ShortType()
return native.ShortType()
if isinstance_or_issubclass(dtype, dtypes.Int8):
return spark_types.ByteType()
return native.ByteType()
if isinstance_or_issubclass(dtype, dtypes.String):
return spark_types.StringType()
return native.StringType()
if isinstance_or_issubclass(dtype, dtypes.Boolean):
return spark_types.BooleanType()
return native.BooleanType()
if isinstance_or_issubclass(dtype, dtypes.Date):
return spark_types.DateType()
return native.DateType()
if isinstance_or_issubclass(dtype, dtypes.Datetime):
dt_time_zone = dtype.time_zone
if dt_time_zone is None:
return spark_types.TimestampNTZType()
return native.TimestampNTZType()
if dt_time_zone != "UTC": # pragma: no cover
msg = f"Only UTC time zone is supported for PySpark, got: {dt_time_zone}"
raise ValueError(msg)
return spark_types.TimestampType()
return native.TimestampType()
if isinstance_or_issubclass(dtype, (dtypes.List, dtypes.Array)):
return spark_types.ArrayType(
return native.ArrayType(
elementType=narwhals_to_native_dtype(
dtype.inner, version=version, spark_types=spark_types
dtype.inner, version=version, spark_types=native
)
)
if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover
return spark_types.StructType(
return native.StructType(
fields=[
spark_types.StructField(
native.StructField(
name=field.name,
dataType=narwhals_to_native_dtype(
field.dtype,
version=version,
spark_types=spark_types,
field.dtype, version=version, spark_types=native
),
)
for field in dtype.fields
Expand Down Expand Up @@ -147,7 +153,7 @@ def narwhals_to_native_dtype(
def evaluate_exprs(
df: SparkLikeLazyFrame, /, *exprs: SparkLikeExpr
) -> list[tuple[str, Column]]:
native_results: list[tuple[str, list[Column]]] = []
native_results: list[tuple[str, Column]] = []

for expr in exprs:
native_series_list = expr._call(df)
Expand Down
4 changes: 2 additions & 2 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import polars as pl
import pyarrow as pa
import pyspark.sql as pyspark_sql
import sqlframe
from typing_extensions import TypeGuard
from typing_extensions import TypeIs

from narwhals._arrow.typing import ArrowChunkedArray
from narwhals._spark_like.dataframe import SQLFrameDataFrame
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.series import Series
Expand Down Expand Up @@ -231,7 +231,7 @@ def is_pyspark_dataframe(df: Any) -> TypeIs[pyspark_sql.DataFrame]:
)


def is_sqlframe_dataframe(df: Any) -> TypeIs[sqlframe.base.dataframe.BaseDataFrame]:
def is_sqlframe_dataframe(df: Any) -> TypeIs[SQLFrameDataFrame]:
"""Check whether `df` is a SQLFrame DataFrame without importing SQLFrame."""
return bool(
(sqlframe := get_sqlframe()) is not None
Expand Down
Loading
Loading