Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from narwhals.utils import Version


class ArrowExpr(CompliantExpr[ArrowSeries]):
class ArrowExpr(CompliantExpr["ArrowDataFrame", ArrowSeries]):
_implementation: Implementation = Implementation.PYARROW

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from narwhals.utils import Version


class ArrowNamespace(CompliantNamespace[ArrowSeries]):
class ArrowNamespace(CompliantNamespace[ArrowDataFrame, ArrowSeries]):
def _create_expr_from_callable(
self: Self,
func: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from narwhals.utils import Version


class DaskExpr(CompliantExpr["dx.Series"]):
class DaskExpr(CompliantExpr["DaskLazyFrame", "dx.Series"]):
_implementation: Implementation = Implementation.DASK

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _from_native_frame(self: Self, df: DaskLazyFrame) -> DaskLazyFrame:
def agg_dask(
df: DaskLazyFrame,
grouped: Any,
exprs: Sequence[CompliantExpr[dx.Series]],
exprs: Sequence[CompliantExpr[DaskLazyFrame, dx.Series]],
keys: list[str],
from_dataframe: Callable[[Any], DaskLazyFrame],
) -> DaskLazyFrame:
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import dask_expr as dx


class DaskNamespace(CompliantNamespace["dx.Series"]):
class DaskNamespace(CompliantNamespace[DaskLazyFrame, "dx.Series"]):
@property
def selectors(self: Self) -> DaskSelectorNamespace:
return DaskSelectorNamespace(self)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from narwhals.utils import Version


class DuckDBExpr(CompliantExpr["duckdb.Expression"]): # type: ignore[type-var]
class DuckDBExpr(CompliantExpr["DuckDBLazyFrame", "duckdb.Expression"]): # type: ignore[type-var]
_implementation = Implementation.DUCKDB
_depth = 0 # Unused, just for compatibility with CompliantExpr

Expand Down
2 changes: 1 addition & 1 deletion narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from narwhals.utils import Version


class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]): # type: ignore[type-var]
class DuckDBNamespace(CompliantNamespace["DuckDBLazyFrame", "duckdb.Expression"]): # type: ignore[type-var]
def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
) -> None:
Expand Down
39 changes: 19 additions & 20 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from typing import Sequence
from typing import TypedDict
from typing import TypeVar
from typing import Union
from typing import overload

from narwhals.dependencies import is_narwhals_series
Expand All @@ -30,16 +29,14 @@
from narwhals.expr import Expr
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantExpr
from narwhals.typing import CompliantFrameT_contra
from narwhals.typing import CompliantLazyFrame
from narwhals.typing import CompliantNamespace
from narwhals.typing import CompliantSeries
from narwhals.typing import CompliantSeriesT_co
from narwhals.typing import IntoExpr
from narwhals.typing import _1DArray

ArrowOrPandasLikeExpr = TypeVar(
"ArrowOrPandasLikeExpr", bound=Union[ArrowExpr, PandasLikeExpr]
)
Comment on lines -40 to -42
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was playing around with fixing the ignores - noticed this was unused

PandasLikeExprT = TypeVar("PandasLikeExprT", bound=PandasLikeExpr)
ArrowExprT = TypeVar("ArrowExprT", bound=ArrowExpr)

Expand All @@ -54,8 +51,8 @@ def is_expr(obj: Any) -> TypeIs[Expr]:


def evaluate_into_expr(
df: CompliantDataFrame | CompliantLazyFrame,
expr: CompliantExpr[CompliantSeriesT_co],
df: CompliantFrameT_contra,
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
) -> Sequence[CompliantSeriesT_co]:
"""Return list of raw columns.

Expand All @@ -75,7 +72,9 @@ def evaluate_into_expr(


def evaluate_into_exprs(
df: CompliantDataFrame, /, *exprs: CompliantExpr[CompliantSeriesT_co]
df: CompliantFrameT_contra,
/,
*exprs: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
) -> list[CompliantSeriesT_co]:
"""Evaluate each expr into Series."""
return [
Expand All @@ -87,7 +86,8 @@ def evaluate_into_exprs(

@overload
def maybe_evaluate_expr(
df: CompliantDataFrame, expr: CompliantExpr[CompliantSeriesT_co]
df: CompliantFrameT_contra,
expr: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co],
) -> CompliantSeriesT_co: ...


Expand All @@ -96,7 +96,7 @@ def maybe_evaluate_expr(df: CompliantDataFrame, expr: T) -> T: ...


def maybe_evaluate_expr(
df: CompliantDataFrame, expr: CompliantExpr[CompliantSeriesT_co] | T
df: Any, expr: CompliantExpr[Any, CompliantSeriesT_co] | T
) -> CompliantSeriesT_co | T:
"""Evaluate `expr` if it's an expression, otherwise return it as is."""
if is_compliant_expr(expr):
Expand Down Expand Up @@ -234,7 +234,7 @@ def reuse_series_namespace_implementation(
)


def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool:
def is_simple_aggregation(expr: CompliantExpr[Any, Any]) -> bool:
"""Check if expr is a very simple one.

Examples:
Expand All @@ -252,24 +252,22 @@ def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool:


def combine_evaluate_output_names(
*exprs: CompliantExpr[Any],
) -> Callable[[CompliantDataFrame | CompliantLazyFrame], Sequence[str]]:
*exprs: CompliantExpr[CompliantFrameT_contra, Any],
) -> Callable[[CompliantFrameT_contra], Sequence[str]]:
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
# first name of `expr1`.
if not is_compliant_expr(exprs[0]): # pragma: no cover
msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
raise AssertionError(msg)

def evaluate_output_names(
df: CompliantDataFrame | CompliantLazyFrame,
) -> Sequence[str]:
def evaluate_output_names(df: CompliantFrameT_contra) -> Sequence[str]:
return exprs[0]._evaluate_output_names(df)[:1]

return evaluate_output_names


def combine_alias_output_names(
*exprs: CompliantExpr[Any],
*exprs: CompliantExpr[Any, Any],
) -> Callable[[Sequence[str]], Sequence[str]] | None:
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the
# aliasing function of `expr1` and apply it to the first output name of `expr1`.
Expand All @@ -283,11 +281,11 @@ def alias_output_names(names: Sequence[str]) -> Sequence[str]:


def extract_compliant(
plx: CompliantNamespace[CompliantSeriesT_co],
plx: CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co],
other: Any,
*,
str_as_lit: bool,
) -> CompliantExpr[CompliantSeriesT_co] | object:
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | object:
if is_expr(other):
return other._to_compliant_expr(plx)
if isinstance(other, str) and not str_as_lit:
Expand All @@ -301,7 +299,7 @@ def extract_compliant(


def evaluate_output_names_and_aliases(
expr: CompliantExpr[Any],
expr: CompliantExpr[Any, Any],
df: CompliantDataFrame | CompliantLazyFrame,
exclude: Sequence[str],
) -> tuple[Sequence[str], Sequence[str]]:
Expand Down Expand Up @@ -446,7 +444,7 @@ def apply_n_ary_operation(
function: Any,
*comparands: IntoExpr,
str_as_lit: bool,
) -> CompliantExpr[Any]:
) -> CompliantExpr[Any, Any]:
compliant_exprs = (
extract_compliant(plx, comparand, str_as_lit=str_as_lit)
for comparand in comparands
Expand All @@ -463,3 +461,4 @@ def apply_n_ary_operation(
for compliant_expr, kind in zip(compliant_exprs, kinds)
)
return function(*compliant_exprs)
return function(*compliant_exprs)
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
}


class PandasLikeExpr(CompliantExpr[PandasLikeSeries]):
class PandasLikeExpr(CompliantExpr["PandasLikeDataFrame", PandasLikeSeries]):
def __init__(
self: Self,
call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]],
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from narwhals.utils import Version


class PandasLikeNamespace(CompliantNamespace[PandasLikeSeries]):
class PandasLikeNamespace(CompliantNamespace[PandasLikeDataFrame, PandasLikeSeries]):
@property
def selectors(self: Self) -> PandasSelectorNamespace:
return PandasSelectorNamespace(self)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from narwhals.utils import Version


class SparkLikeExpr(CompliantExpr["Column"]):
class SparkLikeExpr(CompliantExpr["SparkLikeLazyFrame", "Column"]):
_depth = 0 # Unused, just for compatibility with CompliantExpr

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from narwhals.utils import Version


class SparkLikeNamespace(CompliantNamespace["Column"]):
class SparkLikeNamespace(CompliantNamespace["SparkLikeLazyFrame", "Column"]):
def __init__(
self: Self,
*,
Expand Down
2 changes: 1 addition & 1 deletion narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _from_compliant_dataframe(self: Self, df: Any) -> Self:

def _flatten_and_extract(
self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr
) -> tuple[list[IntoCompliantExpr[Any]], list[ExprKind]]:
) -> tuple[list[IntoCompliantExpr[Any, Any]], list[ExprKind]]:
"""Process `args` and `kwargs`, extracting underlying objects as we go, interpreting strings as column names."""
out_exprs = []
out_kinds = []
Expand Down
8 changes: 4 additions & 4 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,10 +1189,10 @@ def is_between(
"""

def func(
compliant_expr: CompliantExpr[Any],
lb: CompliantExpr[Any],
ub: CompliantExpr[Any],
) -> CompliantExpr[Any]:
compliant_expr: CompliantExpr[Any, Any],
lb: CompliantExpr[Any, Any],
ub: CompliantExpr[Any, Any],
) -> CompliantExpr[Any, Any]:
if closed == "left":
return (compliant_expr >= lb) & (compliant_expr < ub) # type: ignore[no-any-return]
elif closed == "right":
Expand Down
2 changes: 1 addition & 1 deletion narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,7 @@ class Then(Expr):
def otherwise(self: Self, value: IntoExpr | Any) -> Expr:
kind = infer_kind(value, str_as_lit=False)

def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]:
def func(plx: CompliantNamespace[Any, Any]) -> CompliantExpr[Any, Any]:
compliant_expr = self._to_compliant_expr(plx)
compliant_value = extract_compliant(plx, value, str_as_lit=False)
if (
Expand Down
25 changes: 16 additions & 9 deletions narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,25 +72,30 @@ def aggregate(self, *exprs: Any) -> Self:
# (so, no broadcasting is necessary).


CompliantFrameT_contra = TypeVar(
"CompliantFrameT_contra",
bound="CompliantDataFrame | CompliantLazyFrame",
contravariant=True,
)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeries, covariant=True
)


class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]):
class CompliantExpr(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
_implementation: Implementation
_backend_version: tuple[int, ...]
_version: Version
_evaluate_output_names: Callable[
[CompliantDataFrame | CompliantLazyFrame], Sequence[str]
]
_evaluate_output_names: Callable[[CompliantFrameT_contra], Sequence[str]]
_alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None
_depth: int
_function_name: str

def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ...
def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(self) -> CompliantNamespace[CompliantSeriesT_co]: ...
def __narwhals_namespace__(
self,
) -> CompliantNamespace[CompliantFrameT_contra, CompliantSeriesT_co]: ...
def is_null(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def cast(self, dtype: DType) -> Self: ...
Expand All @@ -112,11 +117,13 @@ def broadcast(
) -> Self: ...


class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]):
def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ...
class CompliantNamespace(Protocol, Generic[CompliantFrameT_contra, CompliantSeriesT_co]):
def col(
self, *column_names: str
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...
def lit(
self, value: Any, dtype: DType | None
) -> CompliantExpr[CompliantSeriesT_co]: ...
) -> CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]: ...


class SupportsNativeNamespace(Protocol):
Expand Down Expand Up @@ -316,7 +323,7 @@ class DTypes:
# This one needs to be in TYPE_CHECKING to pass on 3.9,
# and can only be defined after CompliantExpr has been defined
IntoCompliantExpr: TypeAlias = (
CompliantExpr[CompliantSeriesT_co] | CompliantSeriesT_co
CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | CompliantSeriesT_co
)


Expand Down
5 changes: 3 additions & 2 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from narwhals.series import Series
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantExpr
from narwhals.typing import CompliantFrameT_contra
from narwhals.typing import CompliantLazyFrame
from narwhals.typing import CompliantSeries
from narwhals.typing import CompliantSeriesT_co
Expand Down Expand Up @@ -1321,8 +1322,8 @@ def is_compliant_series(obj: Any) -> TypeIs[CompliantSeries]:


def is_compliant_expr(
obj: CompliantExpr[CompliantSeriesT_co] | Any,
) -> TypeIs[CompliantExpr[CompliantSeriesT_co]]:
obj: CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co] | Any,
) -> TypeIs[CompliantExpr[CompliantFrameT_contra, CompliantSeriesT_co]]:
return hasattr(obj, "__narwhals_expr__")


Expand Down
Loading