Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion narwhals/_compliant/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def replace_strict(
*,
return_dtype: DType | type[DType] | None,
) -> Self: ...
def over(self: Self, keys: Sequence[str], order_by: Sequence[str] | None) -> Self: ...
def over(
self, partition_by: Sequence[str], order_by: Sequence[str] | None
) -> Self: ...
Comment on lines +171 to +173
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird that the keys -> partition_by rename didn't show up as an issue elsewhere πŸ€”

def sample(
self,
n: int | None,
Expand Down
78 changes: 72 additions & 6 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Callable
from typing import Literal
from typing import Mapping
from typing import Sequence

import polars as pl
Expand All @@ -18,6 +19,8 @@

from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import ExprMetadata
from narwhals._polars.dataframe import Method
from narwhals._polars.namespace import PolarsNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version

Expand Down Expand Up @@ -61,7 +64,7 @@ def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
return {name: min_samples}

def cast(self: Self, dtype: DType) -> Self:
def cast(self, dtype: DType | type[DType]) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
return self._with_native(self.native.cast(dtype_pl))

Expand Down Expand Up @@ -96,9 +99,7 @@ def is_nan(self: Self) -> Self:
native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
return self._with_native(native)

def over(
self: Self, partition_by: Sequence[str], order_by: Sequence[str] | None
) -> Self:
def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
if self._backend_version < (1, 9):
if order_by:
msg = "`order_by` in Polars requires version 1.10 or greater"
Expand Down Expand Up @@ -147,7 +148,7 @@ def rolling_mean(
return self._with_native(native)

def map_batches(
self: Self, function: Callable[..., Self], return_dtype: DType | None
self, function: Callable[[Any], Any], return_dtype: DType | type[DType] | None
) -> Self:
return_dtype_pl = (
narwhals_to_native_dtype(return_dtype, self._version, self._backend_version)
Expand All @@ -158,7 +159,11 @@ def map_batches(
return self._with_native(native)

def replace_strict(
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: DType | type[DType] | None,
) -> Self:
if self._backend_version < (1,):
msg = f"`replace_strict` is only available in Polars>=1.0, found version {self._backend_version}"
Expand Down Expand Up @@ -226,6 +231,14 @@ def cum_count(self: Self, *, reverse: bool) -> Self:
result = self.native.cum_count(reverse=reverse)
return self._with_native(result)

def __narwhals_expr__(self) -> None: ...
def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
from narwhals._polars.namespace import PolarsNamespace

return PolarsNamespace(
backend_version=self._backend_version, version=self._version
)

@property
def dt(self: Self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
Expand All @@ -250,6 +263,59 @@ def list(self: Self) -> PolarsExprListNamespace:
def struct(self: Self) -> PolarsExprStructNamespace:
return PolarsExprStructNamespace(self)

# CompliantExpr
_alias_output_names: Any
_evaluate_output_names: Any
_is_multi_output_unnamed: Any
__call__: Any
from_column_names: Any
from_column_indices: Any

# Polars
abs: Method[Self]
all: Method[Self]
any: Method[Self]
alias: Method[Self]
arg_max: Method[Self]
arg_min: Method[Self]
arg_true: Method[Self]
count: Method[Self]
cum_max: Method[Self]
cum_min: Method[Self]
cum_prod: Method[Self]
cum_sum: Method[Self]
diff: Method[Self]
drop_nulls: Method[Self]
fill_null: Method[Self]
gather_every: Method[Self]
head: Method[Self]
is_finite: Method[Self]
is_first_distinct: Method[Self]
is_in: Method[Self]
is_last_distinct: Method[Self]
is_null: Method[Self]
is_unique: Method[Self]
len: Method[Self]
max: Method[Self]
mean: Method[Self]
median: Method[Self]
min: Method[Self]
mode: Method[Self]
n_unique: Method[Self]
null_count: Method[Self]
quantile: Method[Self]
rank: Method[Self]
round: Method[Self]
sample: Method[Self]
shift: Method[Self]
skew: Method[Self]
std: Method[Self]
sum: Method[Self]
sort: Method[Self]
tail: Method[Self]
unique: Method[Self]
var: Method[Self]


class PolarsExprDateTimeNamespace:
def __init__(self: Self, expr: PolarsExpr) -> None:
Expand Down
17 changes: 9 additions & 8 deletions narwhals/_polars/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from datetime import timezone

from typing_extensions import Self
from typing_extensions import TypeAlias

from narwhals._compliant import CompliantSelectorNamespace
from narwhals._compliant import CompliantWhen
Expand All @@ -39,8 +38,6 @@
from narwhals.utils import Version
from narwhals.utils import _FullContext

Incomplete: TypeAlias = Any


class PolarsNamespace:
all: Method[PolarsExpr]
Expand All @@ -51,8 +48,10 @@ class PolarsNamespace:
sum_horizontal: Method[PolarsExpr]
min_horizontal: Method[PolarsExpr]
max_horizontal: Method[PolarsExpr]
# NOTE: `PolarsSeries`, `PolarsExpr` still have gaps
when: Method[CompliantWhen[PolarsDataFrame, Incomplete, Incomplete]]

# NOTE: `pyright` accepts, `mypy` doesn't highlight the issue
# error: Type argument "PolarsExpr" of "CompliantWhen" must be a subtype of "CompliantExpr[Any, Any]"
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]] # type: ignore[type-var]

def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
Expand Down Expand Up @@ -231,10 +230,12 @@ def concat_str(
# 1. Others have lots of private stuff for code reuse
# i. None of that is useful here
# 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
# 3. `PolarsExpr` still has it's own gaps in the spec
@property
def selectors(self: Self) -> CompliantSelectorNamespace[Any, Any]:
return cast("CompliantSelectorNamespace[Any, Any]", PolarsSelectorNamespace(self))
def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
return cast(
"CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
PolarsSelectorNamespace(self),
)


class PolarsSelectorNamespace:
Expand Down
81 changes: 75 additions & 6 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import Sequence
from typing import cast
from typing import overload
Expand All @@ -22,8 +24,11 @@
from types import ModuleType
from typing import TypeVar

import pandas as pd
from typing_extensions import Self

from narwhals._arrow.typing import ArrowArray
from narwhals._polars.dataframe import Method
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.namespace import PolarsNamespace
Expand Down Expand Up @@ -177,12 +182,16 @@ def __getitem__(
) -> Any | Self:
return self._from_native_object(self.native.__getitem__(item))

def cast(self: Self, dtype: DType) -> Self:
def cast(self: Self, dtype: DType | type[DType]) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version)
return self._with_native(self.native.cast(dtype_pl))

def replace_strict(
self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None
self: Self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: DType | type[DType] | None,
) -> Self:
ser = self.native
dtype = (
Expand Down Expand Up @@ -586,14 +595,74 @@ def str(self: Self) -> PolarsSeriesStringNamespace:
def cat(self: Self) -> PolarsSeriesCatNamespace:
return PolarsSeriesCatNamespace(self)

@property
def list(self: Self) -> PolarsSeriesListNamespace:
return PolarsSeriesListNamespace(self)

@property
def struct(self: Self) -> PolarsSeriesStructNamespace:
return PolarsSeriesStructNamespace(self)

__iter__: Method[Iterator[Any]]
__floordiv__: Method[Self]
__mod__: Method[Self]
__rand__: Method[Self]
__rfloordiv__: Method[Self]
__rmod__: Method[Self]
__ror__: Method[Self]
__rtruediv__: Method[Self]
__truediv__: Method[Self]
abs: Method[Self]
all: Method[bool]
any: Method[bool]
arg_max: Method[int]
arg_min: Method[int]
arg_true: Method[Self]
clip: Method[Self]
count: Method[int]
cum_max: Method[Self]
cum_min: Method[Self]
cum_prod: Method[Self]
cum_sum: Method[Self]
diff: Method[Self]
drop_nulls: Method[Self]
fill_null: Method[Self]
filter: Method[Self]
gather_every: Method[Self]
head: Method[Self]
is_between: Method[Self]
is_finite: Method[Self]
is_first_distinct: Method[Self]
is_in: Method[Self]
is_last_distinct: Method[Self]
is_null: Method[Self]
is_sorted: Method[bool]
is_unique: Method[Self]
item: Method[Any]
len: Method[int]
max: Method[Any]
mean: Method[float]
min: Method[Any]
mode: Method[Self]
n_unique: Method[int]
null_count: Method[int]
quantile: Method[float]
rank: Method[Self]
round: Method[Self]
sample: Method[Self]
shift: Method[Self]
skew: Method[float | None]
std: Method[float]
sum: Method[float]
tail: Method[Self]
to_arrow: Method[ArrowArray]
to_frame: Method[PolarsDataFrame]
to_list: Method[list[Any]]
to_pandas: Method[pd.Series[Any]]
unique: Method[Self]
var: Method[float]
zip_with: Method[Self]

@property
def list(self: Self) -> PolarsSeriesListNamespace:
return PolarsSeriesListNamespace(self)
Comment on lines +656 to +664
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



class PolarsSeriesDateTimeNamespace:
def __init__(self: Self, series: PolarsSeries) -> None:
Expand Down
Loading