diff --git a/docs/api-reference/typing.md b/docs/api-reference/typing.md index 4097b6378d..b7637ee9ae 100644 --- a/docs/api-reference/typing.md +++ b/docs/api-reference/typing.md @@ -19,6 +19,16 @@ Narwhals comes fully statically typed. In addition to `nw.DataFrame`, `nw.Expr`, - IntoSeriesT - SizeUnit - TimeUnit + - AsofJoinStrategy + - ClosedInterval + - ConcatMethod + - FillNullStrategy + - JoinStrategy + - PivotAgg + - RankMethod + - RollingInterpolationMethod + - UniqueKeepStrategy + - LazyUniqueKeepStrategy show_source: false show_bases: false diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 0f1ec8e31d..0c9d92e7ed 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -59,7 +59,9 @@ from narwhals.schema import Schema from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame + from narwhals.typing import JoinStrategy from narwhals.typing import SizeUnit + from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version @@ -451,7 +453,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -758,7 +760,7 @@ def unique( self: ArrowDataFrame, subset: Sequence[str] | None, *, - keep: Literal["any", "first", "last", "none"], + keep: UniqueKeepStrategy, maintain_order: bool | None = None, ) -> ArrowDataFrame: # The param `maintain_order` is only here for compatibility with the Polars API diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index b97b50ac0f..0098064451 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING from typing import Any -from typing import Literal from typing import Sequence import pyarrow.compute as pc @@ -25,6 +24,7 @@ from narwhals._compliant.typing import EvalNames from narwhals._compliant.typing import EvalSeries from narwhals._expression_parsing import ExprMetadata + from narwhals.typing import RankMethod from narwhals.utils import Version from narwhals.utils import _FullContext @@ -208,12 +208,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self: def cum_prod(self: Self, *, reverse: bool) -> Self: return self._reuse_series("cum_prod", reverse=reverse) - def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: return self._reuse_series("rank", method=method, descending=descending) ewm_mean = not_implemented() diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 58e001ec81..134711fb41 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -34,6 +34,7 @@ from narwhals._arrow.typing import ArrowChunkedArray from narwhals._arrow.typing import Incomplete from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.utils import Version @@ -211,10 +212,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: ) def concat( - self: Self, - items: Iterable[ArrowDataFrame], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[ArrowDataFrame], *, how: ConcatMethod ) -> ArrowDataFrame: dfs = [item.native for item in items] diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 5230cdb410..02510f4a01 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -4,7 +4,6 @@ from typing import Any from typing import Iterable from typing import Iterator -from typing import Literal from typing import Mapping from typing import Sequence from typing import cast @@ -56,7 +55,11 @@ from narwhals._arrow.typing import _AsPyType from narwhals._arrow.typing import _BasicDataType from narwhals.dtypes import DType + from narwhals.typing import ClosedInterval + from narwhals.typing import FillNullStrategy from narwhals.typing import Into1DArray + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version @@ -499,10 +502,7 @@ def all(self: Self, *, _return_py_scalar: bool = True) -> bool: ) def is_between( - self: Self, - lower_bound: Any, - upper_bound: Any, - closed: Literal["left", "right", "none", "both"], + self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval ) -> Self: _, lower_bound = extract_native(self, lower_bound) _, upper_bound = extract_native(self, upper_bound) @@ -636,17 +636,14 @@ def sample( return self._with_native(self.native.take(mask)) def fill_null( - self: Self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: import numpy as np # ignore-banned-import def fill_aux( arr: ArrowArray | ArrowChunkedArray, limit: int, - direction: Literal["forward", "backward"] | None = None, + direction: FillNullStrategy | None = None, ) -> ArrowArray: # this algorithm first finds the indices of the valid values to fill all the null value positions # then it calculates the distance of each new index and the original index @@ -812,9 +809,9 @@ def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFram ).simple_select(*output_order) def quantile( - self: Self, + self, quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + interpolation: RollingInterpolationMethod, *, _return_py_scalar: bool = True, ) -> float: @@ -1028,12 +1025,7 @@ def rolling_std( ** 0.5 ) - def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: if method == "average": msg = ( "`rank` with `method='average' is not supported for pyarrow backend. " diff --git a/narwhals/_compliant/dataframe.py b/narwhals/_compliant/dataframe.py index 60d5397ed9..e22cd60e35 100644 --- a/narwhals/_compliant/dataframe.py +++ b/narwhals/_compliant/dataframe.py @@ -40,7 +40,11 @@ from narwhals._translate import IntoArrowTable from narwhals.dtypes import DType from narwhals.schema import Schema + from narwhals.typing import AsofJoinStrategy + from narwhals.typing import JoinStrategy + from narwhals.typing import LazyUniqueKeepStrategy from narwhals.typing import SizeUnit + from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _2DArray from narwhals.utils import Implementation from narwhals.utils import _FullContext @@ -144,7 +148,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -157,7 +161,7 @@ def join_asof( right_on: str | None, by_left: Sequence[str] | None, by_right: Sequence[str] | None, - strategy: Literal["backward", "forward", "nearest"], + strategy: AsofJoinStrategy, suffix: str, ) -> Self: ... def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrame[Any, Any]: ... @@ -193,7 +197,7 @@ def unique( self, subset: Sequence[str] | None, *, - keep: Literal["any", "first", "last", "none"], + keep: UniqueKeepStrategy, maintain_order: bool | None = None, ) -> Self: ... def unpivot( @@ -284,7 +288,7 @@ def join_asof( right_on: str | None, by_left: Sequence[str] | None, by_right: Sequence[str] | None, - strategy: Literal["backward", "forward", "nearest"], + strategy: AsofJoinStrategy, suffix: str, ) -> Self: ... def rename(self, mapping: Mapping[str, str]) -> Self: ... @@ -295,10 +299,7 @@ def sort( @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.") def tail(self, n: int) -> Self: ... def unique( - self, - subset: Sequence[str] | None, - *, - keep: Literal["any", "none"], + self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy ) -> Self: ... def unpivot( self, diff --git a/narwhals/_compliant/expr.py b/narwhals/_compliant/expr.py index b54ce6acd5..9ae4b663aa 100644 --- a/narwhals/_compliant/expr.py +++ b/narwhals/_compliant/expr.py @@ -64,6 +64,9 @@ from narwhals._expression_parsing import ExprKind from narwhals._expression_parsing import ExprMetadata from narwhals.dtypes import DType + from narwhals.typing import FillNullStrategy + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.typing import TimeUnit from narwhals.utils import Implementation from narwhals.utils import Version @@ -131,10 +134,7 @@ def n_unique(self) -> Self: ... def null_count(self) -> Self: ... def drop_nulls(self) -> Self: ... def fill_null( - self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... def diff(self) -> Self: ... def unique(self) -> Self: ... @@ -156,12 +156,7 @@ def cum_max(self, *, reverse: bool) -> Self: ... def cum_prod(self, *, reverse: bool) -> Self: ... def is_in(self, other: Any) -> Self: ... def sort(self, *, descending: bool, nulls_last: bool) -> Self: ... - def rank( - self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: ... + def rank(self, method: RankMethod, *, descending: bool) -> Self: ... def replace_strict( self, old: Sequence[Any] | Mapping[Any, Any], @@ -181,9 +176,7 @@ def sample( seed: int | None, ) -> Self: ... def quantile( - self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: ... def map_batches( self, @@ -657,10 +650,7 @@ def is_nan(self) -> Self: return self._reuse_series("is_nan") def fill_null( - self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: return self._reuse_series( "fill_null", value=value, strategy=strategy, limit=limit @@ -746,9 +736,7 @@ def is_last_distinct(self) -> Self: return self._reuse_series("is_last_distinct") def quantile( - self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: return self._reuse_series( "quantile", diff --git a/narwhals/_compliant/namespace.py b/narwhals/_compliant/namespace.py index a8c10a5fc5..fbfe626f2c 100644 --- a/narwhals/_compliant/namespace.py +++ b/narwhals/_compliant/namespace.py @@ -5,7 +5,6 @@ from typing import Any from typing import Container from typing import Iterable -from typing import Literal from typing import Mapping from typing import Protocol from typing import Sequence @@ -35,6 +34,7 @@ from narwhals._compliant.when_then import EagerWhen from narwhals.dtypes import DType from narwhals.schema import Schema + from narwhals.typing import ConcatMethod from narwhals.typing import Into1DArray from narwhals.typing import _2DArray from narwhals.utils import Implementation @@ -75,10 +75,7 @@ def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ... def concat( - self, - items: Iterable[CompliantFrameT], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[CompliantFrameT], *, how: ConcatMethod ) -> CompliantFrameT: ... def when( self, predicate: CompliantExprT diff --git a/narwhals/_compliant/series.py b/narwhals/_compliant/series.py index 81d6ae045e..5ebcb80c7b 100644 --- a/narwhals/_compliant/series.py +++ b/narwhals/_compliant/series.py @@ -5,7 +5,6 @@ from typing import Generic from typing import Iterable from typing import Iterator -from typing import Literal from typing import Mapping from typing import Protocol from typing import Sequence @@ -40,7 +39,11 @@ from narwhals._compliant.namespace import CompliantNamespace from narwhals._compliant.namespace import EagerNamespace from narwhals.dtypes import DType + from narwhals.typing import ClosedInterval + from narwhals.typing import FillNullStrategy from narwhals.typing import Into1DArray + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.typing import _1DArray from narwhals.utils import Implementation from narwhals.utils import Version @@ -152,10 +155,7 @@ def ewm_mean( ignore_nulls: bool, ) -> Self: ... def fill_null( - self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ... def filter(self, predicate: Any) -> Self: ... def gather_every(self, n: int, offset: int) -> Self: ... @@ -169,10 +169,7 @@ def hist( ) -> CompliantDataFrame[Self, Any, Any]: ... def head(self, n: int) -> Self: ... def is_between( - self, - lower_bound: Any, - upper_bound: Any, - closed: Literal["left", "right", "none", "both"], + self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval ) -> Self: ... def is_finite(self) -> Self: ... def is_first_distinct(self) -> Self: ... @@ -192,16 +189,9 @@ def mode(self) -> Self: ... def n_unique(self) -> int: ... def null_count(self) -> int: ... def quantile( - self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> float: ... - def rank( - self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: ... + def rank(self, method: RankMethod, *, descending: bool) -> Self: ... def replace_strict( self, old: Sequence[Any] | Mapping[Any, Any], diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 34029698bc..01af781fd1 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterator -from typing import Literal from typing import Mapping from typing import Sequence @@ -37,6 +36,9 @@ from narwhals._dask.group_by import DaskLazyGroupBy from narwhals._dask.namespace import DaskNamespace from narwhals.dtypes import DType + from narwhals.typing import AsofJoinStrategy + from narwhals.typing import JoinStrategy + from narwhals.typing import LazyUniqueKeepStrategy from narwhals.utils import Version from narwhals.utils import _FullContext @@ -215,10 +217,7 @@ def head(self: Self, n: int) -> Self: return self._with_native(self.native.head(n=n, compute=False, npartitions=-1)) def unique( - self: Self, - subset: Sequence[str] | None, - *, - keep: Literal["any", "none"], + self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy ) -> Self: check_column_exists(self.columns, subset) if keep == "none": @@ -252,7 +251,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -383,7 +382,7 @@ def join_asof( right_on: str | None, by_left: Sequence[str] | None, by_right: Sequence[str] | None, - strategy: Literal["backward", "forward", "nearest"], + strategy: AsofJoinStrategy, suffix: str, ) -> Self: plx = self.__native_namespace__() diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index addc27f322..6fa8e7e324 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -26,6 +26,8 @@ if TYPE_CHECKING: from narwhals._compliant.typing import AliasNames from narwhals._expression_parsing import ExprKind + from narwhals.typing import FillNullStrategy + from narwhals.typing import RollingInterpolationMethod try: import dask.dataframe.dask_expr as dx @@ -478,11 +480,11 @@ def any(self: Self) -> Self: ) def fill_null( - self: Self, + self, value: Self | Any | None, - strategy: Literal["forward", "backward"] | None, + strategy: FillNullStrategy | None, limit: int | None, - ) -> DaskExpr: + ) -> Self: def func(_input: dx.Series) -> dx.Series: if value is not None: res_ser = _input.fillna(value) @@ -537,9 +539,7 @@ def len(self: Self) -> Self: return self._with_callable(lambda _input: _input.size.to_series(), "len") def quantile( - self: Self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: if interpolation == "linear": diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 88ef002a24..389faed6e0 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterable -from typing import Literal from typing import Sequence import dask.dataframe as dd @@ -31,6 +30,7 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.utils import Version try: @@ -151,10 +151,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: ) def concat( - self: Self, - items: Iterable[DaskLazyFrame], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod ) -> DaskLazyFrame: if not items: msg = "No items to concatenate" # pragma: no cover diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 32769157b3..b81be8e153 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterator -from typing import Literal from typing import Mapping from typing import Sequence @@ -44,6 +43,9 @@ from narwhals._duckdb.namespace import DuckDBNamespace from narwhals._duckdb.series import DuckDBInterchangeSeries from narwhals.dtypes import DType + from narwhals.typing import AsofJoinStrategy + from narwhals.typing import JoinStrategy + from narwhals.typing import LazyUniqueKeepStrategy from narwhals.utils import _FullContext from narwhals.typing import CompliantLazyFrame @@ -254,7 +256,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -308,7 +310,7 @@ def join_asof( right_on: str | None, by_left: Sequence[str] | None, by_right: Sequence[str] | None, - strategy: Literal["backward", "forward", "nearest"], + strategy: AsofJoinStrategy, suffix: str, ) -> Self: lhs = self.native @@ -354,7 +356,7 @@ def collect_schema(self: Self) -> dict[str, DType]: } def unique( - self: Self, subset: Sequence[str] | None, *, keep: Literal["any", "none"] + self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy ) -> Self: if subset_ := subset if keep == "any" else (subset or self.columns): # Sanitise input diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py index 75fc2461c7..85a1c92fa7 100644 --- a/narwhals/_duckdb/expr.py +++ b/narwhals/_duckdb/expr.py @@ -42,6 +42,8 @@ from narwhals._duckdb.typing import WindowFunction from narwhals._expression_parsing import ExprMetadata from narwhals.dtypes import DType + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.utils import Version from narwhals.utils import _FullContext @@ -384,9 +386,7 @@ def any(self: Self) -> Self: return self._with_callable(lambda _input: FunctionExpression("bool_or", _input)) def quantile( - self: Self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: def func(_input: duckdb.Expression) -> duckdb.Expression: if interpolation == "linear": @@ -690,12 +690,7 @@ def func(_input: duckdb.Expression) -> duckdb.Expression: return self._with_callable(func) - def rank( - self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: if method == "min": func_name = "rank" elif method == "dense": diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index cb7ec12ef2..b755be3723 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterable -from typing import Literal from typing import Sequence import duckdb @@ -33,6 +32,7 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.utils import Version @@ -60,10 +60,7 @@ def _lazyframe(self) -> type[DuckDBLazyFrame]: return DuckDBLazyFrame def concat( - self: Self, - items: Iterable[DuckDBLazyFrame], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod ) -> DuckDBLazyFrame: if how == "horizontal": msg = "horizontal concat not supported for duckdb. Please join instead" diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 82314ceb98..519307681e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -61,10 +61,13 @@ from narwhals._translate import IntoArrowTable from narwhals.dtypes import DType from narwhals.schema import Schema + from narwhals.typing import AsofJoinStrategy from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame from narwhals.typing import DTypeBackend + from narwhals.typing import JoinStrategy from narwhals.typing import SizeUnit + from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version @@ -673,7 +676,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -824,7 +827,7 @@ def join_asof( right_on: str | None, by_left: Sequence[str] | None, by_right: Sequence[str] | None, - strategy: Literal["backward", "forward", "nearest"], + strategy: AsofJoinStrategy, suffix: str, ) -> Self: plx = self.__native_namespace__() @@ -850,10 +853,10 @@ def tail(self: Self, n: int) -> Self: return self._with_native(self.native.tail(n), validate_column_names=False) def unique( - self: Self, + self, subset: Sequence[str] | None, *, - keep: Literal["any", "first", "last", "none"], + keep: UniqueKeepStrategy, maintain_order: bool | None = None, ) -> Self: # The param `maintain_order` is only here for compatibility with the Polars API diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index a2664598d2..0f4584a59d 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -2,7 +2,6 @@ from typing import TYPE_CHECKING from typing import Any -from typing import Literal from typing import Sequence from narwhals._compliant import EagerExpr @@ -21,6 +20,7 @@ from narwhals._expression_parsing import ExprMetadata from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace + from narwhals.typing import RankMethod from narwhals.utils import Implementation from narwhals.utils import Version from narwhals.utils import _FullContext @@ -365,12 +365,7 @@ def rolling_var( }, ) - def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: return self._reuse_series( "rank", call_kwargs={"method": method, "descending": descending} ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 627fc4fee5..9553edc1f9 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterable -from typing import Literal from narwhals._compliant import CompliantThen from narwhals._compliant import EagerNamespace @@ -27,6 +26,7 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.utils import Implementation from narwhals.utils import Version @@ -224,10 +224,7 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) def concat( - self: Self, - items: Iterable[PandasLikeDataFrame], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[PandasLikeDataFrame], *, how: ConcatMethod ) -> PandasLikeDataFrame: dfs: list[Any] = [item._native_frame for item in items] if how == "horizontal": diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 626a28c3d0..cdd304d275 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -4,7 +4,6 @@ from typing import Any from typing import Iterable from typing import Iterator -from typing import Literal from typing import Mapping from typing import Sequence from typing import cast @@ -48,7 +47,11 @@ from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.namespace import PandasLikeNamespace from narwhals.dtypes import DType + from narwhals.typing import ClosedInterval + from narwhals.typing import FillNullStrategy from narwhals.typing import Into1DArray + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.typing import _1DArray from narwhals.typing import _AnyDArray from narwhals.utils import Version @@ -334,11 +337,8 @@ def to_list(self: Self) -> list[Any]: return self.native.to_arrow().to_pylist() if is_cudf else self.native.to_list() def is_between( - self: Self, - lower_bound: Any, - upper_bound: Any, - closed: Literal["left", "right", "none", "both"], - ) -> PandasLikeSeries: + self, lower_bound: Any, upper_bound: Any, closed: ClosedInterval + ) -> Self: ser = self.native _, lower_bound = align_and_extract_native(self, lower_bound) _, upper_bound = align_and_extract_native(self, upper_bound) @@ -552,10 +552,7 @@ def is_nan(self: Self) -> PandasLikeSeries: raise InvalidOperationError(msg) def fill_null( - self: Self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: ser = self.native if value is not None: @@ -762,9 +759,7 @@ def value_counts( return PandasLikeDataFrame.from_native(val_count, context=self) def quantile( - self: Self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self, quantile: float, interpolation: RollingInterpolationMethod ) -> float: return self.native.quantile(q=quantile, interpolation=interpolation) @@ -921,12 +916,7 @@ def is_finite(self: Self) -> Self: s = self.native return self._with_native((s > float("-inf")) & (s < float("inf"))) - def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: pd_method = "first" if method == "ordinal" else method name = self.name if ( diff --git a/narwhals/_polars/dataframe.py b/narwhals/_polars/dataframe.py index f1a70fe0e6..a6dae67223 100644 --- a/narwhals/_polars/dataframe.py +++ b/narwhals/_polars/dataframe.py @@ -43,6 +43,7 @@ from narwhals.schema import Schema from narwhals.typing import CompliantDataFrame from narwhals.typing import CompliantLazyFrame + from narwhals.typing import JoinStrategy from narwhals.typing import _2DArray from narwhals.utils import Version from narwhals.utils import _FullContext @@ -446,7 +447,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, @@ -672,7 +673,7 @@ def join( self: Self, other: Self, *, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 8a1300c7a2..acb7e284b5 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -6,7 +6,6 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterator -from typing import Literal from typing import Mapping from typing import Sequence @@ -44,6 +43,8 @@ from narwhals._spark_like.group_by import SparkLikeLazyGroupBy from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals.dtypes import DType + from narwhals.typing import JoinStrategy + from narwhals.typing import LazyUniqueKeepStrategy from narwhals.utils import Version from narwhals.utils import _FullContext @@ -331,10 +332,7 @@ def rename(self: Self, mapping: Mapping[str, str]) -> Self: ) def unique( - self: Self, - subset: Sequence[str] | None, - *, - keep: Literal["any", "none"], + self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy ) -> Self: check_column_exists(self.columns, subset) subset = list(subset) if subset else None @@ -352,7 +350,7 @@ def unique( def join( self: Self, other: Self, - how: Literal["inner", "left", "full", "cross", "semi", "anti"], + how: JoinStrategy, left_on: Sequence[str] | None, right_on: Sequence[str] | None, suffix: str, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 968a903d8d..62df2343a4 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -37,6 +37,8 @@ from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals._spark_like.typing import WindowFunction from narwhals.dtypes import DType + from narwhals.typing import FillNullStrategy + from narwhals.typing import RankMethod from narwhals.utils import Version from narwhals.utils import _FullContext @@ -695,10 +697,7 @@ def cum_prod(self, *, reverse: bool) -> Self: ) def fill_null( - self, - value: Any | None, - strategy: Literal["forward", "backward"] | None, - limit: int | None, + self, value: Any | None, strategy: FillNullStrategy | None, limit: int | None ) -> Self: if strategy is not None: msg = "Support for strategies is not yet implemented." @@ -755,12 +754,7 @@ def rolling_std( ) ) - def rank( - self, - method: Literal["average", "min", "max", "dense", "ordinal"], - *, - descending: bool, - ) -> Self: + def rank(self, method: RankMethod, *, descending: bool) -> Self: if method == "min": func_name = "rank" elif method == "dense": diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index c9bd25e43e..6d97f0f670 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -4,7 +4,6 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Iterable -from typing import Literal from typing import Sequence from narwhals._compliant import CompliantThen @@ -23,6 +22,7 @@ from narwhals._spark_like.dataframe import SQLFrameDataFrame # noqa: F401 from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.utils import Implementation from narwhals.utils import Version @@ -189,10 +189,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) def concat( - self: Self, - items: Iterable[SparkLikeLazyFrame], - *, - how: Literal["horizontal", "vertical", "diagonal"], + self, items: Iterable[SparkLikeLazyFrame], *, how: ConcatMethod ) -> SparkLikeLazyFrame: dfs = [item._native_frame for item in items] if how == "horizontal": diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index f7574a95f6..bd7a14edc8 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -55,10 +55,15 @@ from narwhals.group_by import GroupBy from narwhals.group_by import LazyGroupBy from narwhals.series import Series + from narwhals.typing import AsofJoinStrategy from narwhals.typing import IntoDataFrame from narwhals.typing import IntoExpr from narwhals.typing import IntoFrame + from narwhals.typing import JoinStrategy + from narwhals.typing import LazyUniqueKeepStrategy + from narwhals.typing import PivotAgg from narwhals.typing import SizeUnit + from narwhals.typing import UniqueKeepStrategy from narwhals.typing import _1DArray from narwhals.typing import _2DArray @@ -225,7 +230,7 @@ def join( self: Self, other: Self, on: str | list[str] | None = None, - how: Literal["inner", "left", "full", "cross", "semi", "anti"] = "inner", + how: JoinStrategy = "inner", *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, @@ -285,7 +290,7 @@ def join_asof( by_left: str | list[str] | None = None, by_right: str | list[str] | None = None, by: str | list[str] | None = None, - strategy: Literal["backward", "forward", "nearest"] = "backward", + strategy: AsofJoinStrategy = "backward", suffix: str = "_right", ) -> Self: _supported_strategies = ("backward", "forward", "nearest") @@ -1397,7 +1402,7 @@ def unique( self: Self, subset: str | list[str] | None = None, *, - keep: Literal["any", "first", "last", "none"] = "any", + keep: UniqueKeepStrategy = "any", maintain_order: bool = False, ) -> Self: """Drop duplicate rows from this dataframe. @@ -1604,7 +1609,7 @@ def join( self: Self, other: Self, on: str | list[str] | None = None, - how: Literal["inner", "left", "full", "cross", "semi", "anti"] = "inner", + how: JoinStrategy = "inner", *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, @@ -1659,7 +1664,7 @@ def join_asof( by_left: str | list[str] | None = None, by_right: str | list[str] | None = None, by: str | list[str] | None = None, - strategy: Literal["backward", "forward", "nearest"] = "backward", + strategy: AsofJoinStrategy = "backward", suffix: str = "_right", ) -> Self: """Perform an asof join. @@ -1891,10 +1896,7 @@ def pivot( *, index: str | list[str] | None = None, values: str | list[str] | None = None, - aggregate_function: Literal[ - "min", "max", "first", "last", "sum", "mean", "median", "len" - ] - | None = None, + aggregate_function: PivotAgg | None = None, maintain_order: bool | None = None, sort_columns: bool = False, separator: str = "_", @@ -2658,7 +2660,7 @@ def unique( self: Self, subset: str | list[str] | None = None, *, - keep: Literal["any", "none"] = "any", + keep: LazyUniqueKeepStrategy = "any", ) -> Self: """Drop duplicate rows from this LazyFrame. @@ -2891,7 +2893,7 @@ def join( self: Self, other: Self, on: str | list[str] | None = None, - how: Literal["inner", "left", "full", "cross", "semi", "anti"] = "inner", + how: JoinStrategy = "inner", *, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, @@ -2955,7 +2957,7 @@ def join_asof( by_left: str | list[str] | None = None, by_right: str | list[str] | None = None, by: str | list[str] | None = None, - strategy: Literal["backward", "forward", "nearest"] = "backward", + strategy: AsofJoinStrategy = "backward", suffix: str = "_right", ) -> Self: """Perform an asof join. diff --git a/narwhals/expr.py b/narwhals/expr.py index 1e15fe4924..c996a6dd67 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -4,7 +4,6 @@ from typing import Any from typing import Callable from typing import Iterable -from typing import Literal from typing import Mapping from typing import Sequence @@ -39,7 +38,11 @@ from narwhals._compliant import CompliantExpr from narwhals._compliant import CompliantNamespace from narwhals.dtypes import DType + from narwhals.typing import ClosedInterval + from narwhals.typing import FillNullStrategy from narwhals.typing import IntoExpr + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod PS = ParamSpec("PS") R = TypeVar("R") @@ -1133,7 +1136,7 @@ def is_between( self: Self, lower_bound: Any | IntoExpr, upper_bound: Any | IntoExpr, - closed: Literal["left", "right", "none", "both"] = "both", + closed: ClosedInterval = "both", ) -> Self: """Check if this expression is between the given lower and upper bounds. @@ -1363,7 +1366,7 @@ def arg_true(self: Self) -> Self: def fill_null( self: Self, value: Expr | Any | None = None, - strategy: Literal["forward", "backward"] | None = None, + strategy: FillNullStrategy | None = None, limit: int | None = None, ) -> Self: """Fill null values with given value. @@ -1757,9 +1760,7 @@ def is_last_distinct(self: Self) -> Self: ) def quantile( - self: Self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self: Self, quantile: float, interpolation: RollingInterpolationMethod ) -> Self: r"""Get quantile value. @@ -2481,10 +2482,7 @@ def rolling_std( ) def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"] = "average", - *, - descending: bool = False, + self: Self, method: RankMethod = "average", *, descending: bool = False ) -> Self: """Assign ranks to data, dealing with ties appropriately. diff --git a/narwhals/functions.py b/narwhals/functions.py index 91a782c62e..bd902716d3 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -57,6 +57,7 @@ from narwhals.dtypes import DType from narwhals.schema import Schema from narwhals.series import Series + from narwhals.typing import ConcatMethod from narwhals.typing import IntoExpr from narwhals.typing import IntoSeriesT from narwhals.typing import NativeFrame @@ -68,11 +69,7 @@ FrameT = TypeVar("FrameT", "DataFrame[Any]", "LazyFrame[Any]") -def concat( - items: Iterable[FrameT], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> FrameT: +def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT: """Concatenate multiple DataFrames, LazyFrames into a single entity. Arguments: diff --git a/narwhals/series.py b/narwhals/series.py index d7620522e7..c3a57795b7 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -37,6 +37,10 @@ from narwhals._compliant import CompliantSeries from narwhals.dataframe import DataFrame from narwhals.dtypes import DType + from narwhals.typing import ClosedInterval + from narwhals.typing import FillNullStrategy + from narwhals.typing import RankMethod + from narwhals.typing import RollingInterpolationMethod from narwhals.typing import _1DArray from narwhals.utils import Implementation @@ -1279,7 +1283,7 @@ def is_nan(self: Self) -> Self: def fill_null( self: Self, value: Any | None = None, - strategy: Literal["forward", "backward"] | None = None, + strategy: FillNullStrategy | None = None, limit: int | None = None, ) -> Self: """Fill null values using the specified value. @@ -1338,7 +1342,7 @@ def is_between( self: Self, lower_bound: Any | Self, upper_bound: Any | Self, - closed: Literal["left", "right", "none", "both"] = "both", + closed: ClosedInterval = "both", ) -> Self: """Get a boolean mask of the values that are between the given lower/upper bounds. @@ -1804,9 +1808,7 @@ def value_counts( ) def quantile( - self: Self, - quantile: float, - interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + self: Self, quantile: float, interpolation: RollingInterpolationMethod ) -> float: """Get quantile value of the series. @@ -2485,10 +2487,7 @@ def __contains__(self: Self, other: Any) -> bool: return self._compliant_series.__contains__(other) def rank( - self: Self, - method: Literal["average", "min", "max", "dense", "ordinal"] = "average", - *, - descending: bool = False, + self: Self, method: RankMethod = "average", *, descending: bool = False ) -> Self: """Assign ranks to data, dealing with ties appropriately. diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index fa4f937ca9..8dfb4838e5 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -92,6 +92,7 @@ from narwhals._translate import IntoArrowTable from narwhals.dtypes import DType + from narwhals.typing import ConcatMethod from narwhals.typing import IntoExpr from narwhals.typing import IntoFrame from narwhals.typing import IntoLazyFrameT @@ -2059,11 +2060,7 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: return _stableify(nw.max_horizontal(*exprs)) -def concat( - items: Iterable[FrameT], - *, - how: Literal["horizontal", "vertical", "diagonal"] = "vertical", -) -> FrameT: +def concat(items: Iterable[FrameT], *, how: ConcatMethod = "vertical") -> FrameT: """Concatenate multiple DataFrames, LazyFrames into a single entity. Arguments: diff --git a/narwhals/typing.py b/narwhals/typing.py index 6f481768a5..ec7e7d5915 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -208,6 +208,90 @@ def __native_namespace__(self) -> ModuleType: ... TimeUnit: TypeAlias = Literal["ns", "us", "ms", "s"] +AsofJoinStrategy: TypeAlias = Literal["backward", "forward", "nearest"] +"""Join strategy. + +- *"backward"*: Selects the last row in the right DataFrame whose `on` key + is less than or equal to the left's key. +- *"forward"*: Selects the first row in the right DataFrame whose `on` key + is greater than or equal to the left's key. +- *"nearest"*: Search selects the last row in the right DataFrame whose value + is nearest to the left's key. +""" + +ClosedInterval: TypeAlias = Literal["left", "right", "none", "both"] +"""Define which sides of the interval are closed (inclusive).""" + +ConcatMethod: TypeAlias = Literal["horizontal", "vertical", "diagonal"] +"""Concatenating strategy. + +- *"vertical"*: Concatenate vertically. Column names must match. +- *"horizontal"*: Concatenate horizontally. If lengths don't match, then + missing rows are filled with null values. +- *"diagonal"*: Finds a union between the column schemas and fills missing + column values with null. +""" + +FillNullStrategy: TypeAlias = Literal["forward", "backward"] +"""Strategy used to fill null values.""" + +JoinStrategy: TypeAlias = Literal["inner", "left", "full", "cross", "semi", "anti"] +"""Join strategy. + +- *"inner"*: Returns rows that have matching values in both tables. +- *"left"*: Returns all rows from the left table, and the matched rows from + the right table. +- *"full"*: Returns all rows in both dataframes, with the `suffix` appended to + the right join keys. +- *"cross"*: Returns the Cartesian product of rows from both tables. +- *"semi"*: Filter rows that have a match in the right table. +- *"anti"*: Filter rows that do not have a match in the right table. +""" + +PivotAgg: TypeAlias = Literal[ + "min", "max", "first", "last", "sum", "mean", "median", "len" +] +"""A predefined aggregate function string.""" + +RankMethod: TypeAlias = Literal["average", "min", "max", "dense", "ordinal"] +"""The method used to assign ranks to tied elements. + +- *"average"*: The average of the ranks that would have been assigned to + all the tied values is assigned to each value. +- *"min"*: The minimum of the ranks that would have been assigned to all + the tied values is assigned to each value. (This is also referred to + as "competition" ranking.) +- *"max"*: The maximum of the ranks that would have been assigned to all + the tied values is assigned to each value. +- *"dense"*: Like "min", but the rank of the next highest element is + assigned the rank immediately after those assigned to the tied elements. +- *"ordinal"*: All values are given a distinct rank, corresponding to the + order that the values occur in the Series. +""" + +RollingInterpolationMethod: TypeAlias = Literal[ + "nearest", "higher", "lower", "midpoint", "linear" +] +"""Interpolation method.""" + +UniqueKeepStrategy: TypeAlias = Literal["any", "first", "last", "none"] +"""Which of the duplicate rows to keep. + +- *"any"*: Does not give any guarantee of which row is kept. + This allows more optimizations. +- *"none"*: Don't keep duplicate rows. +- *"first"*: Keep first unique row. +- *"last"*: Keep last unique row. +""" + +LazyUniqueKeepStrategy: TypeAlias = Literal["any", "none"] +"""Which of the duplicate rows to keep. + +- *"any"*: Does not give any guarantee of which row is kept. +- *"none"*: Don't keep duplicate rows. +""" + + _ShapeT = TypeVar("_ShapeT", bound="tuple[int, ...]") _NDArray: TypeAlias = "np.ndarray[_ShapeT, Any]" _1DArray: TypeAlias = "_NDArray[tuple[int]]" # noqa: PYI042