diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index d3378ea8eb..baeb9a9a40 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -30,7 +30,7 @@ from narwhals.utils import Version -class ArrowExpr(CompliantExpr[ArrowSeries]): +class ArrowExpr(CompliantExpr["ArrowDataFrame", ArrowSeries]): _implementation: Implementation = Implementation.PYARROW def __init__( diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index f11e69af8e..4ae03060d7 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -40,7 +40,7 @@ from narwhals.utils import Version -class ArrowNamespace(CompliantNamespace[ArrowSeries]): +class ArrowNamespace(CompliantNamespace[ArrowDataFrame, ArrowSeries]): def _create_expr_from_callable( self: Self, func: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index a0af9e64e0..434eac540a 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -35,7 +35,7 @@ from narwhals.utils import Version -class DaskExpr(CompliantExpr["dx.Series"]): +class DaskExpr(CompliantExpr["DaskLazyFrame", "dx.Series"]): _implementation: Implementation = Implementation.DASK def __init__( diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 39f79145e3..13b57796a2 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -107,7 +107,7 @@ def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame: def agg_dask( df: DaskLazyFrame, grouped: Any, - exprs: Sequence[CompliantExpr[dx.Series]], + exprs: Sequence[CompliantExpr[DaskLazyFrame, dx.Series]], keys: list[str], from_dataframe: Callable[[Any], DaskLazyFrame], ) -> DaskLazyFrame: diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 1ef2d7e843..53ef7dbc8b 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -38,7 +38,7 @@ import dask_expr as dx -class DaskNamespace(CompliantNamespace["dx.Series"]): +class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]): @property def selectors(self: Self) -> DaskSelectorNamespace: return DaskSelectorNamespace(self) diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 0dc9aab63f..45f614489a 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -33,7 +33,7 @@ from narwhals.utils import Version -class DuckDBExpr(CompliantExpr["duckdb.Expression"]): # type: ignore[type-var] +class DuckDBExpr(CompliantExpr["DuckDBLazyFrame", "duckdb.Expression"]): # type: ignore[type-var] _implementation = Implementation.DUCKDB _depth = 0 # Unused, just for compatibility with CompliantExpr diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index 2f50a70724..bc5de2fd2c 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -34,7 +34,7 @@ from narwhals.utils import Version -class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]): # type: ignore[type-var] +class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]): # type: ignore[type-var] def __init__( self: Self, *, backend_version: tuple[int, ...], version: Version ) -> None: diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index ba0c81f9cb..0e91c02d09 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -12,7 +12,6 @@ from typing import Sequence from typing import TypedDict from typing import TypeVar -from typing import Union from typing import overload from narwhals.dependencies import is_narwhals_series @@ -30,6 +29,7 @@ from narwhals.expr import Expr from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr + from narwhals.typing import CompliantFrameT_contra from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantNamespace from narwhals.typing import CompliantSeries @@ -37,9 +37,6 @@ from narwhals.typing import IntoExpr from narwhals.typing import _1DArray - ArrowOrPandasLikeExpr = TypeVar( - "ArrowOrPandasLikeExpr", bound=Union[ArrowExpr, PandasLikeExpr] - ) PandasLikeExprT = TypeVar("PandasLikeExprT", bound=PandasLikeExpr) ArrowExprT = TypeVar("ArrowExprT", bound=ArrowExpr) @@ -54,8 +51,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]: def evaluate_into_expr( - df: CompliantDataFrame | CompliantLazyFrame, - expr: CompliantExpr[CompliantSeriesT_co], + df: CompliantFrameT_contra, + expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], ) -> Sequence[CompliantSeriesT_co]: """Return list of raw columns. @@ -75,7 +72,9 @@ def evaluate_into_expr( def evaluate_into_exprs( - df: CompliantDataFrame, /, *exprs: CompliantExpr[CompliantSeriesT_co] + df: CompliantFrameT_contra, + /, + *exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], ) -> list[CompliantSeriesT_co]: """Evaluate each expr into Series.""" return [ @@ -87,7 +86,8 @@ def evaluate_into_exprs( @overload def maybe_evaluate_expr( - df: CompliantDataFrame, expr: CompliantExpr[CompliantSeriesT_co] + df: CompliantFrameT_contra, + expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co], ) -> CompliantSeriesT_co: ... @@ -96,7 +96,7 @@ def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ... def maybe_evaluate_expr( - df: CompliantDataFrame, expr: CompliantExpr[CompliantSeriesT_co] | T + df: Any, expr: CompliantExpr[Any, CompliantSeriesT_co] | T ) -> CompliantSeriesT_co | T: """Evaluate `expr` if it's an expression, otherwise return it as is.""" if is_compliant_expr(expr): @@ -234,7 +234,7 @@ def reuse_series_namespace_implementation( ) -def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool: +def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool: """Check if expr is a very simple one. Examples: @@ -252,24 +252,22 @@ def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool: def combine_evaluate_output_names( - *exprs: CompliantExpr[Any], -) -> Callable[[CompliantDataFrame | CompliantLazyFrame], Sequence[str]]: + *exprs: CompliantExpr[CompliantFrameT_contra, Any], +) -> Callable[[CompliantFrameT_contra], Sequence[str]]: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the # first name of `expr1`. if not is_compliant_expr(exprs[0]): # pragma: no cover msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug." raise AssertionError(msg) - def evaluate_output_names( - df: CompliantDataFrame | CompliantLazyFrame, - ) -> Sequence[str]: + def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]: return exprs[0]._evaluate_output_names(df)[:1] return evaluate_output_names def combine_alias_output_names( - *exprs: CompliantExpr[Any], + *exprs: CompliantExpr[Any, Any], ) -> Callable[[Sequence[str]], Sequence[str]] | None: # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the # aliasing function of `expr1` and apply it to the first output name of `expr1`. @@ -283,11 +281,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]: def extract_compliant( - plx: CompliantNamespace[CompliantSeriesT_co], + plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co], other: Any, *, str_as_lit: bool, -) -> CompliantExpr[CompliantSeriesT_co] | object: +) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object: if is_expr(other): return other._to_compliant_expr(plx) if isinstance(other, str) and not str_as_lit: @@ -301,7 +299,7 @@ def extract_compliant( def evaluate_output_names_and_aliases( - expr: CompliantExpr[Any], + expr: CompliantExpr[Any, Any], df: CompliantDataFrame | CompliantLazyFrame, exclude: Sequence[str], ) -> tuple[Sequence[str], Sequence[str]]: @@ -446,7 +444,7 @@ def apply_n_ary_operation( function: Any, *comparands: IntoExpr, str_as_lit: bool, -) -> CompliantExpr[Any]: +) -> CompliantExpr[Any, Any]: compliant_exprs = ( extract_compliant(plx, comparand, str_as_lit=str_as_lit) for comparand in comparands @@ -463,3 +461,4 @@ def apply_n_ary_operation( for compliant_expr, kind in zip(compliant_exprs, kinds) ) return function(*compliant_exprs) + return function(*compliant_exprs) diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index f9694c0777..1bf5ffd592 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -46,7 +46,7 @@ } -class PandasLikeExpr(CompliantExpr[PandasLikeSeries]): +class PandasLikeExpr(CompliantExpr["PandasLikeDataFrame", PandasLikeSeries]): def __init__( self: Self, call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index d930630e38..c5275d2d4a 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -33,7 +33,7 @@ from narwhals.utils import Version -class PandasLikeNamespace(CompliantNamespace[PandasLikeSeries]): +class PandasLikeNamespace(CompliantNamespace[PandasLikeDataFrame, PandasLikeSeries]): @property def selectors(self: Self) -> PandasSelectorNamespace: return PandasSelectorNamespace(self) diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 97c176061f..046755de95 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -27,7 +27,7 @@ from narwhals.utils import Version -class SparkLikeExpr(CompliantExpr["Column"]): +class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]): _depth = 0 # Unused, just for compatibility with CompliantExpr def __init__( diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index b6f51a60f8..6946e1ff85 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -28,7 +28,7 @@ from narwhals.utils import Version -class SparkLikeNamespace(CompliantNamespace["Column"]): +class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]): def __init__( self: Self, *, diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 90a18935a8..c22125b679 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -82,7 +82,7 @@ def _from_compliant_dataframe(self: Self, df: Any) -> Self: def _flatten_and_extract( self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr - ) -> tuple[list[IntoCompliantExpr[Any]], list[ExprKind]]: + ) -> tuple[list[IntoCompliantExpr[Any, Any]], list[ExprKind]]: """Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names.""" out_exprs = [] out_kinds = [] diff --git a/narwhals/expr.py b/narwhals/expr.py index 2c323329bc..d628456aa8 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -1189,10 +1189,10 @@ def is_between( """ def func( - compliant_expr: CompliantExpr[Any], - lb: CompliantExpr[Any], - ub: CompliantExpr[Any], - ) -> CompliantExpr[Any]: + compliant_expr: CompliantExpr[Any, Any], + lb: CompliantExpr[Any, Any], + ub: CompliantExpr[Any, Any], + ) -> CompliantExpr[Any, Any]: if closed == "left": return (compliant_expr >= lb) & (compliant_expr < ub) # type: ignore[no-any-return] elif closed == "right": diff --git a/narwhals/functions.py b/narwhals/functions.py index 22634f9847..2cd172ae0c 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -1458,7 +1458,7 @@ class Then(Expr): def otherwise(self: Self, value: IntoExpr | Any) -> Expr: kind = infer_kind(value, str_as_lit=False) - def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: + def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]: compliant_expr = self._to_compliant_expr(plx) compliant_value = extract_compliant(plx, value, str_as_lit=False) if ( diff --git a/narwhals/typing.py b/narwhals/typing.py index 4923bbed7b..98457fab71 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -72,25 +72,30 @@ def aggregate(self, *exprs: Any) -> Self: # (so, no broadcasting is necessary). +CompliantFrameT_contra = TypeVar( + "CompliantFrameT_contra", + bound="CompliantDataFrame | CompliantLazyFrame", + contravariant=True, +) CompliantSeriesT_co = TypeVar( "CompliantSeriesT_co", bound=CompliantSeries, covariant=True ) -class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]): +class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): _implementation: Implementation _backend_version: tuple[int, ...] _version: Version - _evaluate_output_names: Callable[ - [CompliantDataFrame | CompliantLazyFrame], Sequence[str] - ] + _evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]] _alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None _depth: int _function_name: str def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ... def __narwhals_expr__(self) -> None: ... - def __narwhals_namespace__(self) -> CompliantNamespace[CompliantSeriesT_co]: ... + def __narwhals_namespace__( + self, + ) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ... def is_null(self) -> Self: ... def alias(self, name: str) -> Self: ... def cast(self, dtype: DType) -> Self: ... @@ -112,11 +117,13 @@ def broadcast( ) -> Self: ... -class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]): - def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ... +class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]): + def col( + self, *column_names: str + ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... def lit( self, value: Any, dtype: DType | None - ) -> CompliantExpr[CompliantSeriesT_co]: ... + ) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ... class SupportsNativeNamespace(Protocol): @@ -316,7 +323,7 @@ class DTypes: # This one needs to be in TYPE_CHECKING to pass on 3.9, # and can only be defined after CompliantExpr has been defined IntoCompliantExpr: TypeAlias = ( - CompliantExpr[CompliantSeriesT_co] | CompliantSeriesT_co + CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co ) diff --git a/narwhals/utils.py b/narwhals/utils.py index cb33a603d6..4b3183bf7e 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -53,6 +53,7 @@ from narwhals.series import Series from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantExpr + from narwhals.typing import CompliantFrameT_contra from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantSeries from narwhals.typing import CompliantSeriesT_co @@ -1321,8 +1322,8 @@ def is_compliant_series(obj: Any) -> TypeIs[CompliantSeries]: def is_compliant_expr( - obj: CompliantExpr[CompliantSeriesT_co] | Any, -) -> TypeIs[CompliantExpr[CompliantSeriesT_co]]: + obj: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | Any, +) -> TypeIs[CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]]: return hasattr(obj, "__narwhals_expr__")