Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,7 @@ def concat(

@property
def selectors(self: Self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace(
backend_version=self._backend_version, version=self._version
)
return ArrowSelectorNamespace(self)

def when(self: Self, predicate: ArrowExpr) -> ArrowWhen:
return ArrowWhen(predicate, self._backend_version, version=self._version)
Expand Down
117 changes: 36 additions & 81 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import re
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Sequence

from narwhals._arrow.expr import ArrowExpr
from narwhals.utils import Implementation
from narwhals.utils import _parse_time_unit_and_time_zone
from narwhals.utils import dtype_matches_time_unit_and_time_zone
from narwhals.utils import import_dtypes_module
Expand All @@ -21,16 +21,13 @@
from narwhals._arrow.series import ArrowSeries
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
from narwhals.utils import Version
from narwhals.utils import _LimitedContext


class ArrowSelectorNamespace:
def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
) -> None:
self._backend_version = backend_version
self._implementation = Implementation.PYARROW
self._version = version
def __init__(self: Self, context: _LimitedContext, /) -> None:
self._backend_version = context._backend_version
self._version = context._version

def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand All @@ -39,16 +36,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return ArrowSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={"dtypes": dtypes},
)
return selector(self, func, evaluate_output_names, {"dtypes": dtypes})

def matches(self: Self, pattern: str) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
Expand All @@ -57,16 +45,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
return [col for col in df.columns if re.search(pattern, col)]

return ArrowSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={"pattern": pattern},
)
return selector(self, func, evaluate_output_names, {"pattern": pattern})

def numeric(self: Self) -> ArrowSelector:
dtypes = import_dtypes_module(self._version)
Expand Down Expand Up @@ -103,16 +82,7 @@ def all(self: Self) -> ArrowSelector:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
return [df[col] for col in df.columns]

return ArrowSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=lambda df: df.columns,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, lambda df: df.columns, {})

def datetime(
self: Self,
Expand Down Expand Up @@ -148,16 +118,7 @@ def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]:
)
]

return ArrowSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, evaluate_output_names, {})


class ArrowSelector(ArrowExpr):
Expand Down Expand Up @@ -190,15 +151,8 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "other": other},
return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
else:
return self._to_expr() - other
Expand All @@ -221,15 +175,8 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "other": other},
return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)
else:
return self._to_expr() | other
Expand All @@ -248,23 +195,31 @@ def evaluate_output_names(df: ArrowDataFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "other": other},
return selector(
self, call, evaluate_output_names, {**self._kwargs, "other": other}
)

else:
return self._to_expr() & other

def __invert__(self: Self) -> ArrowSelector:
return (
ArrowSelectorNamespace(
backend_version=self._backend_version, version=self._version
).all()
- self
)
return ArrowSelectorNamespace(self).all() - self


def selector(
context: _LimitedContext,
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
kwargs: dict[str, Any],
/,
) -> ArrowSelector:
return ArrowSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=context._backend_version,
version=context._version,
kwargs=kwargs,
)
4 changes: 1 addition & 3 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
class DaskNamespace(CompliantNamespace["dx.Series"]):
@property
def selectors(self: Self) -> DaskSelectorNamespace:
return DaskSelectorNamespace(
backend_version=self._backend_version, version=self._version
)
return DaskSelectorNamespace(self)

def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
Expand Down
113 changes: 31 additions & 82 deletions narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Sequence

Expand All @@ -24,7 +25,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
from narwhals.utils import Version
from narwhals.utils import _LimitedContext

try:
import dask.dataframe.dask_expr as dx
Expand All @@ -33,11 +34,9 @@


class DaskSelectorNamespace:
def __init__(
self: Self, *, backend_version: tuple[int, ...], version: Version
) -> None:
self._backend_version = backend_version
self._version = version
def __init__(self: Self, context: _LimitedContext, /) -> None:
self._backend_version = context._backend_version
self._version = context._version

def by_dtype(self: Self, dtypes: Iterable[DType | type[DType]]) -> DaskSelector:
def func(df: DaskLazyFrame) -> list[dx.Series]:
Expand All @@ -48,16 +47,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]:
return [col for col in df.columns if df.schema[col] in dtypes]

return DaskSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, evaluate_output_names)

def matches(self: Self, pattern: str) -> DaskSelector:
def func(df: DaskLazyFrame) -> list[dx.Series]:
Expand All @@ -68,16 +58,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]:
return [col for col in df.columns if re.search(pattern, col)]

return DaskSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, evaluate_output_names)

def numeric(self: Self) -> DaskSelector:
dtypes = import_dtypes_module(self._version)
Expand Down Expand Up @@ -114,16 +95,7 @@ def all(self: Self) -> DaskSelector:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df._native_frame[col] for col in df.columns]

return DaskSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=lambda df: df.columns,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, lambda df: df.columns)

def datetime(
self: Self,
Expand Down Expand Up @@ -159,16 +131,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]:
)
]

return DaskSelector(
func,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, func, evaluate_output_names)


class DaskSelector(DaskExpr):
Expand Down Expand Up @@ -201,16 +164,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x not in rhs_names]

return DaskSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() - other

Expand All @@ -232,16 +186,7 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

return DaskSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() | other

Expand All @@ -259,23 +204,27 @@ def evaluate_output_names(df: DaskLazyFrame) -> list[str]:
rhs_names = other._evaluate_output_names(df)
return [x for x in lhs_names if x in rhs_names]

return DaskSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=self._backend_version,
version=self._version,
kwargs={},
)
return selector(self, call, evaluate_output_names)
else:
return self._to_expr() & other

def __invert__(self: Self) -> DaskSelector:
return (
DaskSelectorNamespace(
backend_version=self._backend_version, version=self._version
).all()
- self
)
return DaskSelectorNamespace(self).all() - self


def selector(
context: _LimitedContext,
call: Callable[[DaskLazyFrame], Sequence[dx.Series]],
evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]],
/,
) -> DaskSelector:
return DaskSelector(
call,
depth=0,
function_name="selector",
evaluate_output_names=evaluate_output_names,
alias_output_names=None,
backend_version=context._backend_version,
version=context._version,
kwargs={},
)
Loading
Loading