diff --git a/narwhals/_arrow/expr_cat.py b/narwhals/_arrow/expr_cat.py index 6a26ee97f6..dbea11c59f 100644 --- a/narwhals/_arrow/expr_cat.py +++ b/narwhals/_arrow/expr_cat.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from narwhals._arrow.utils import ArrowExprNamespace from narwhals._expression_parsing import reuse_series_namespace_implementation if TYPE_CHECKING: @@ -10,11 +11,8 @@ from narwhals._arrow.expr import ArrowExpr -class ArrowExprCatNamespace: - def __init__(self: Self, expr: ArrowExpr) -> None: - self._compliant_expr = expr - +class ArrowExprCatNamespace(ArrowExprNamespace): def get_categories(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "cat", "get_categories" + self.compliant, "cat", "get_categories" ) diff --git a/narwhals/_arrow/expr_dt.py b/narwhals/_arrow/expr_dt.py index 30d2e22c8f..a090d74920 100644 --- a/narwhals/_arrow/expr_dt.py +++ b/narwhals/_arrow/expr_dt.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from narwhals._arrow.utils import ArrowExprNamespace from narwhals._expression_parsing import reuse_series_namespace_implementation if TYPE_CHECKING: @@ -11,97 +12,84 @@ from narwhals.typing import TimeUnit -class ArrowExprDateTimeNamespace: - def __init__(self: Self, expr: ArrowExpr) -> None: - self._compliant_expr = expr - +class ArrowExprDateTimeNamespace(ArrowExprNamespace): def to_string(self: Self, format: str) -> ArrowExpr: # noqa: A002 return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "to_string", format=format + self.compliant, "dt", "to_string", format=format ) def replace_time_zone(self: Self, time_zone: str | None) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "replace_time_zone", time_zone=time_zone + self.compliant, "dt", "replace_time_zone", time_zone=time_zone ) def convert_time_zone(self: Self, time_zone: str) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "convert_time_zone", time_zone=time_zone + self.compliant, "dt", "convert_time_zone", time_zone=time_zone ) def timestamp(self: Self, time_unit: TimeUnit) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "timestamp", time_unit=time_unit + self.compliant, "dt", "timestamp", time_unit=time_unit ) def date(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "date") + return reuse_series_namespace_implementation(self.compliant, "dt", "date") def year(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "year") + return reuse_series_namespace_implementation(self.compliant, "dt", "year") def month(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "month") + return reuse_series_namespace_implementation(self.compliant, "dt", "month") def day(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "day") + return reuse_series_namespace_implementation(self.compliant, "dt", "day") def hour(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "hour") + return reuse_series_namespace_implementation(self.compliant, "dt", "hour") def minute(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "minute") + return reuse_series_namespace_implementation(self.compliant, "dt", "minute") def second(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._compliant_expr, "dt", "second") + return reuse_series_namespace_implementation(self.compliant, "dt", "second") def millisecond(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "millisecond" - ) + return reuse_series_namespace_implementation(self.compliant, "dt", "millisecond") def microsecond(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "microsecond" - ) + return reuse_series_namespace_implementation(self.compliant, "dt", "microsecond") def nanosecond(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "nanosecond" - ) + return reuse_series_namespace_implementation(self.compliant, "dt", "nanosecond") def ordinal_day(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "ordinal_day" - ) + return reuse_series_namespace_implementation(self.compliant, "dt", "ordinal_day") def weekday(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "weekday" - ) + return reuse_series_namespace_implementation(self.compliant, "dt", "weekday") def total_minutes(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "total_minutes" + self.compliant, "dt", "total_minutes" ) def total_seconds(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "total_seconds" + self.compliant, "dt", "total_seconds" ) def total_milliseconds(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "total_milliseconds" + self.compliant, "dt", "total_milliseconds" ) def total_microseconds(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "total_microseconds" + self.compliant, "dt", "total_microseconds" ) def total_nanoseconds(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "dt", "total_nanoseconds" + self.compliant, "dt", "total_nanoseconds" ) diff --git a/narwhals/_arrow/expr_list.py b/narwhals/_arrow/expr_list.py index 8e8e4c1f00..14c81f5935 100644 --- a/narwhals/_arrow/expr_list.py +++ b/narwhals/_arrow/expr_list.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from narwhals._arrow.utils import ArrowExprNamespace from narwhals._expression_parsing import reuse_series_namespace_implementation if TYPE_CHECKING: @@ -10,9 +11,6 @@ from narwhals._arrow.expr import ArrowExpr -class ArrowExprListNamespace: - def __init__(self: Self, expr: ArrowExpr) -> None: - self._expr = expr - +class ArrowExprListNamespace(ArrowExprNamespace): def len(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation(self._expr, "list", "len") + return reuse_series_namespace_implementation(self.compliant, "list", "len") diff --git a/narwhals/_arrow/expr_name.py b/narwhals/_arrow/expr_name.py index 9a7b54919a..2f8b368466 100644 --- a/narwhals/_arrow/expr_name.py +++ b/narwhals/_arrow/expr_name.py @@ -4,16 +4,15 @@ from typing import Callable from typing import Sequence +from narwhals._arrow.utils import ArrowExprNamespace + if TYPE_CHECKING: from typing_extensions import Self from narwhals._arrow.expr import ArrowExpr -class ArrowExprNameNamespace: - def __init__(self: Self, expr: ArrowExpr) -> None: - self._compliant_expr = expr - +class ArrowExprNameNamespace(ArrowExprNamespace): def keep(self: Self) -> ArrowExpr: return self._from_colname_func_and_alias_output_names( name_mapping_func=lambda name: name, @@ -65,19 +64,18 @@ def _from_colname_func_and_alias_output_names( name_mapping_func: Callable[[str], str], alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, ) -> ArrowExpr: - return self._compliant_expr.__class__( + return self.compliant.__class__( call=lambda df: [ series.alias(name_mapping_func(name)) for series, name in zip( - self._compliant_expr._call(df), - self._compliant_expr._evaluate_output_names(df), + self.compliant._call(df), self.compliant._evaluate_output_names(df) ) ], - depth=self._compliant_expr._depth, - function_name=self._compliant_expr._function_name, - evaluate_output_names=self._compliant_expr._evaluate_output_names, + depth=self.compliant._depth, + function_name=self.compliant._function_name, + evaluate_output_names=self.compliant._evaluate_output_names, alias_output_names=alias_output_names, - backend_version=self._compliant_expr._backend_version, - version=self._compliant_expr._version, - call_kwargs=self._compliant_expr._call_kwargs, + backend_version=self.compliant._backend_version, + version=self.compliant._version, + call_kwargs=self.compliant._call_kwargs, ) diff --git a/narwhals/_arrow/expr_str.py b/narwhals/_arrow/expr_str.py index 11ba75914b..fbabfb551f 100644 --- a/narwhals/_arrow/expr_str.py +++ b/narwhals/_arrow/expr_str.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING +from narwhals._arrow.utils import ArrowExprNamespace from narwhals._expression_parsing import reuse_series_namespace_implementation if TYPE_CHECKING: @@ -10,20 +11,15 @@ from narwhals._arrow.expr import ArrowExpr -class ArrowExprStringNamespace: - def __init__(self: Self, expr: ArrowExpr) -> None: - self._compliant_expr = expr - +class ArrowExprStringNamespace(ArrowExprNamespace): def len_chars(self: Self) -> ArrowExpr: - return reuse_series_namespace_implementation( - self._compliant_expr, "str", "len_chars" - ) + return reuse_series_namespace_implementation(self.compliant, "str", "len_chars") def replace( self: Self, pattern: str, value: str, *, literal: bool, n: int ) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, + self.compliant, "str", "replace", pattern=pattern, @@ -34,7 +30,7 @@ def replace( def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, + self.compliant, "str", "replace_all", pattern=pattern, @@ -44,27 +40,27 @@ def replace_all(self: Self, pattern: str, value: str, *, literal: bool) -> Arrow def strip_chars(self: Self, characters: str | None) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "strip_chars", characters=characters + self.compliant, "str", "strip_chars", characters=characters ) def starts_with(self: Self, prefix: str) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "starts_with", prefix=prefix + self.compliant, "str", "starts_with", prefix=prefix ) def ends_with(self: Self, suffix: str) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "ends_with", suffix=suffix + self.compliant, "str", "ends_with", suffix=suffix ) def contains(self: Self, pattern: str, *, literal: bool) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "contains", pattern=pattern, literal=literal + self.compliant, "str", "contains", pattern=pattern, literal=literal ) def slice(self: Self, offset: int, length: int | None) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "slice", offset=offset, length=length + self.compliant, "str", "slice", offset=offset, length=length ) def split(self: Self, by: str) -> ArrowExpr: @@ -74,15 +70,15 @@ def split(self: Self, by: str) -> ArrowExpr: def to_datetime(self: Self, format: str | None) -> ArrowExpr: # noqa: A002 return reuse_series_namespace_implementation( - self._compliant_expr, "str", "to_datetime", format=format + self.compliant, "str", "to_datetime", format=format ) def to_uppercase(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "to_uppercase" + self.compliant, "str", "to_uppercase" ) def to_lowercase(self: Self) -> ArrowExpr: return reuse_series_namespace_implementation( - self._compliant_expr, "str", "to_lowercase" + self.compliant, "str", "to_lowercase" ) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 4ed21d76f4..0b33da9d86 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -28,6 +28,7 @@ from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantSeries from narwhals.utils import Implementation +from narwhals.utils import _StoresNative from narwhals.utils import generate_temporary_column_name from narwhals.utils import import_dtypes_module from narwhals.utils import validate_backend_version @@ -94,7 +95,7 @@ def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: return value -class ArrowSeries(CompliantSeries): +class ArrowSeries(CompliantSeries, _StoresNative["ArrowChunkedArray"]): def __init__( self: Self, native_series: ArrowChunkedArray, @@ -113,7 +114,7 @@ def __init__( def _change_version(self: Self, version: Version) -> Self: return self.__class__( - self._native_series, + self.native, name=self._name, backend_version=self._backend_version, version=version, @@ -154,7 +155,7 @@ def __narwhals_namespace__(self: Self) -> ArrowNamespace: ) def __len__(self: Self) -> int: - return len(self._native_series) + return len(self.native) def __eq__(self: Self, other: object) -> Self: # type: ignore[override] ser, other = extract_native(self, other) @@ -260,16 +261,18 @@ def __rmod__(self: Self, other: Any) -> Self: return self._from_native_series(res) def __invert__(self: Self) -> Self: - return self._from_native_series( - pc.invert(self._native_series) # type: ignore[call-overload] - ) + return self._from_native_series(pc.invert(self.native)) # type: ignore[call-overload] @property def _type(self: Self) -> pa.DataType: - return self._native_series.type + return self.native.type + + @property + def native(self) -> ArrowChunkedArray: + return self._native_series def len(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) + return maybe_extract_py_scalar(len(self.native), _return_py_scalar) def filter(self: Self, predicate: ArrowSeries | list[bool | None]) -> Self: if not ( @@ -278,13 +281,13 @@ def filter(self: Self, predicate: ArrowSeries | list[bool | None]) -> Self: _, other_native = extract_native(self, predicate) else: other_native = predicate - return self._from_native_series(self._native_series.filter(other_native)) # pyright: ignore[reportArgumentType] + return self._from_native_series(self.native.filter(other_native)) # pyright: ignore[reportArgumentType] def mean(self: Self, *, _return_py_scalar: bool = True) -> float: # NOTE: stub overly strict https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L274-L307 # docs say numeric https://arrow.apache.org/docs/python/generated/pyarrow.compute.mean.html mean: Incomplete = pc.mean - return maybe_extract_py_scalar(mean(self._native_series), _return_py_scalar) + return maybe_extract_py_scalar(mean(self.native), _return_py_scalar) def median(self: Self, *, _return_py_scalar: bool = True) -> float: from narwhals.exceptions import InvalidOperationError @@ -294,55 +297,53 @@ def median(self: Self, *, _return_py_scalar: bool = True) -> float: raise InvalidOperationError(msg) return maybe_extract_py_scalar( - pc.approximate_median(self._native_series), _return_py_scalar + pc.approximate_median(self.native), _return_py_scalar ) def min(self: Self, *, _return_py_scalar: bool = True) -> Any: - return maybe_extract_py_scalar(pc.min(self._native_series), _return_py_scalar) + return maybe_extract_py_scalar(pc.min(self.native), _return_py_scalar) def max(self: Self, *, _return_py_scalar: bool = True) -> Any: - return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) + return maybe_extract_py_scalar(pc.max(self.native), _return_py_scalar) def arg_min(self: Self, *, _return_py_scalar: bool = True) -> int: - index_min = pc.index(self._native_series, pc.min(self._native_series)) + index_min = pc.index(self.native, pc.min(self.native)) return maybe_extract_py_scalar(index_min, _return_py_scalar) def arg_max(self: Self, *, _return_py_scalar: bool = True) -> int: - index_max = pc.index(self._native_series, pc.max(self._native_series)) + index_max = pc.index(self.native, pc.max(self.native)) return maybe_extract_py_scalar(index_max, _return_py_scalar) def sum(self: Self, *, _return_py_scalar: bool = True) -> float: return maybe_extract_py_scalar( - pc.sum(self._native_series, min_count=0), _return_py_scalar + pc.sum(self.native, min_count=0), _return_py_scalar ) def drop_nulls(self: Self) -> Self: - return self._from_native_series(self._native_series.drop_null()) + return self._from_native_series(self.native.drop_null()) def shift(self: Self, n: int) -> Self: - ca = self._native_series if n > 0: - arrays = [nulls_like(n, self), *ca[:-n].chunks] + arrays = [nulls_like(n, self), *self.native[:-n].chunks] elif n < 0: - arrays = [*ca[-n:].chunks, nulls_like(-n, self)] + arrays = [*self.native[-n:].chunks, nulls_like(-n, self)] else: - return self._from_native_series(ca) + return self._from_native_series(self.native) return self._from_native_series(pa.concat_arrays(arrays)) def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: return maybe_extract_py_scalar( - pc.stddev(self._native_series, ddof=ddof), _return_py_scalar + pc.stddev(self.native, ddof=ddof), _return_py_scalar ) def var(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: return maybe_extract_py_scalar( - pc.variance(self._native_series, ddof=ddof), _return_py_scalar + pc.variance(self.native, ddof=ddof), _return_py_scalar ) def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: - ser = self._native_series # NOTE: stub issue with `pc.subtract`, `pc.mean` and `pa.ChunkedArray` - ser_not_null: Incomplete = ser.drop_null() + ser_not_null: Incomplete = self.native.drop_null() if len(ser_not_null) == 0: return None elif len(ser_not_null) == 1: @@ -359,12 +360,11 @@ def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: return maybe_extract_py_scalar(biased_population_skewness, _return_py_scalar) def count(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) + return maybe_extract_py_scalar(pc.count(self.native), _return_py_scalar) def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int: - unique_values = self._native_series.unique() return maybe_extract_py_scalar( - pc.count(unique_values, mode="all"), _return_py_scalar + pc.count(self.native.unique(), mode="all"), _return_py_scalar ) def __native_namespace__(self: Self) -> ModuleType: @@ -393,14 +393,10 @@ def __getitem__( self: Self, idx: int | slice | Sequence[int] | ArrowChunkedArray ) -> Any | Self: if isinstance(idx, int): - return maybe_extract_py_scalar( - self._native_series[idx], return_py_scalar=True - ) + return maybe_extract_py_scalar(self.native[idx], return_py_scalar=True) if isinstance(idx, (Sequence, pa.ChunkedArray)): - return self._from_native_series( - self._native_series.take(cast("Indices", idx)) - ) - return self._from_native_series(self._native_series[idx]) + return self._from_native_series(self.native.take(cast("Indices", idx))) + return self._from_native_series(self.native[idx]) def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: import numpy as np # ignore-banned-import @@ -413,7 +409,7 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: # https://github.com/narwhals-dev/narwhals/issues/2155 indices_native = pa.array(indices) if isinstance(values, self.__class__): - values_native = values._native_series.combine_chunks() + values_native = values.native.combine_chunks() else: values_native = pa.array(values) @@ -424,24 +420,24 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: mask: _1DArray = np.zeros(self.len(), dtype=bool) mask[indices_native] = True result = pc.replace_with_mask( - self._native_series, + self.native, cast("list[bool]", mask), values_native.take(cast("Indices", indices_native)), ) return self._from_native_series(result) def to_list(self: Self) -> list[Any]: - return self._native_series.to_pylist() + return self.native.to_pylist() def __array__(self: Self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray: - return self._native_series.__array__(dtype=dtype, copy=copy) + return self.native.__array__(dtype=dtype, copy=copy) def to_numpy(self: Self) -> _1DArray: - return self._native_series.to_numpy() + return self.native.to_numpy() def alias(self: Self, name: str) -> Self: return self.__class__( - self._native_series, + self.native, name=name, backend_version=self._backend_version, version=self._version, @@ -449,20 +445,19 @@ def alias(self: Self, name: str) -> Self: @property def dtype(self: Self) -> DType: - return native_to_narwhals_dtype(self._native_series.type, self._version) + return native_to_narwhals_dtype(self.native.type, self._version) def abs(self: Self) -> Self: - return self._from_native_series(pc.abs(self._native_series)) + return self._from_native_series(pc.abs(self.native)) def cum_sum(self: Self, *, reverse: bool) -> Self: - native_series = self._native_series # NOTE: stub only permits `NumericArray` # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L140 cum_sum: Incomplete = pc.cumulative_sum result = ( - cum_sum(native_series, skip_nulls=True) + cum_sum(self.native, skip_nulls=True) if not reverse - else cum_sum(native_series[::-1], skip_nulls=True)[::-1] + else cum_sum(self.native[::-1], skip_nulls=True)[::-1] ) return self._from_native_series(result) @@ -471,28 +466,28 @@ def round(self: Self, decimals: int) -> Self: # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L140 pc_round: Incomplete = pc.round return self._from_native_series( - pc_round(self._native_series, decimals, round_mode="half_towards_infinity") + pc_round(self.native, decimals, round_mode="half_towards_infinity") ) def diff(self: Self) -> Self: # NOTE: stub only permits `ChunkedArray[TemporalScalar]` # (https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L145-L148) diff: Incomplete = pc.pairwise_diff - return self._from_native_series(diff(self._native_series.combine_chunks())) + return self._from_native_series(diff(self.native.combine_chunks())) def any(self: Self, *, _return_py_scalar: bool = True) -> bool: # NOTE: stub restricts to `BooleanArray`, should be based on truthiness # Copies `pc.all` pc_any: Incomplete = pc.any return maybe_extract_py_scalar( - pc_any(self._native_series, min_count=0), _return_py_scalar + pc_any(self.native, min_count=0), _return_py_scalar ) def all(self: Self, *, _return_py_scalar: bool = True) -> bool: # NOTE: stub restricts to `BooleanArray`, should be based on truthiness pc_all: Incomplete = pc.all return maybe_extract_py_scalar( - pc_all(self._native_series, min_count=0), _return_py_scalar + pc_all(self.native, min_count=0), _return_py_scalar ) def is_between( @@ -501,73 +496,66 @@ def is_between( upper_bound: Any, closed: Literal["left", "right", "none", "both"], ) -> Self: - ser = self._native_series _, lower_bound = extract_native(self, lower_bound) _, upper_bound = extract_native(self, upper_bound) if closed == "left": - ge = pc.greater_equal(ser, lower_bound) - lt = pc.less(ser, upper_bound) + ge = pc.greater_equal(self.native, lower_bound) + lt = pc.less(self.native, upper_bound) res = pc.and_kleene(ge, lt) elif closed == "right": - gt = pc.greater(ser, lower_bound) - le = pc.less_equal(ser, upper_bound) + gt = pc.greater(self.native, lower_bound) + le = pc.less_equal(self.native, upper_bound) res = pc.and_kleene(gt, le) elif closed == "none": - gt = pc.greater(ser, lower_bound) - lt = pc.less(ser, upper_bound) + gt = pc.greater(self.native, lower_bound) + lt = pc.less(self.native, upper_bound) res = pc.and_kleene(gt, lt) elif closed == "both": - ge = pc.greater_equal(ser, lower_bound) - le = pc.less_equal(ser, upper_bound) + ge = pc.greater_equal(self.native, lower_bound) + le = pc.less_equal(self.native, upper_bound) res = pc.and_kleene(ge, le) else: # pragma: no cover raise AssertionError return self._from_native_series(res) def is_null(self: Self) -> Self: - ser = self._native_series - return self._from_native_series(ser.is_null()) + return self._from_native_series(self.native.is_null()) def is_nan(self: Self) -> Self: - return self._from_native_series(pc.is_nan(self._native_series)) + return self._from_native_series(pc.is_nan(self.native)) def cast(self: Self, dtype: DType | type[DType]) -> Self: - ser = self._native_series data_type = narwhals_to_native_dtype(dtype, self._version) - return self._from_native_series(pc.cast(ser, data_type)) + return self._from_native_series(pc.cast(self.native, data_type)) def null_count(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) + return maybe_extract_py_scalar(self.native.null_count, _return_py_scalar) def head(self: Self, n: int) -> Self: - ser = self._native_series if n >= 0: - return self._from_native_series(ser.slice(0, n)) + return self._from_native_series(self.native.slice(0, n)) else: - num_rows = len(ser) - return self._from_native_series(ser.slice(0, max(0, num_rows + n))) + num_rows = len(self) + return self._from_native_series(self.native.slice(0, max(0, num_rows + n))) def tail(self: Self, n: int) -> Self: - ser = self._native_series if n >= 0: - num_rows = len(ser) - return self._from_native_series(ser.slice(max(0, num_rows - n))) + num_rows = len(self) + return self._from_native_series(self.native.slice(max(0, num_rows - n))) else: - return self._from_native_series(ser.slice(abs(n))) + return self._from_native_series(self.native.slice(abs(n))) def is_in(self: Self, other: Any) -> Self: if isinstance(other, pa.ChunkedArray): value_set: ArrowChunkedArray | ArrowArray = other else: value_set = pa.array(other) - ser = self._native_series - return self._from_native_series(pc.is_in(ser, value_set=value_set)) + return self._from_native_series(pc.is_in(self.native, value_set=value_set)) def arg_true(self: Self) -> Self: import numpy as np # ignore-banned-import - ser = self._native_series - res = np.flatnonzero(ser) + res = np.flatnonzero(self.native) return self._from_iterable( res, name=self.name, @@ -583,8 +571,8 @@ def item(self: Self, index: int | None = None) -> Any: f" or an explicit index is provided (Series is of length {len(self)})" ) raise ValueError(msg) - return maybe_extract_py_scalar(self._native_series[0], return_py_scalar=True) - return maybe_extract_py_scalar(self._native_series[index], return_py_scalar=True) + return maybe_extract_py_scalar(self.native[0], return_py_scalar=True) + return maybe_extract_py_scalar(self.native[index], return_py_scalar=True) def value_counts( self: Self, @@ -600,7 +588,7 @@ def value_counts( index_name_ = "index" if self._name is None else self._name value_name_ = name or ("proportion" if normalize else "count") - val_counts = pc.value_counts(self._native_series) + val_counts = pc.value_counts(self.native) values = val_counts.field("values") counts = cast("ArrowChunkedArray", val_counts.field("counts")) @@ -622,10 +610,8 @@ def value_counts( ) def zip_with(self: Self, mask: Self, other: Self) -> Self: - cond = mask._native_series.combine_chunks() - return self._from_native_series( - pc.if_else(cond, self._native_series, other._native_series) - ) + cond = mask.native.combine_chunks() + return self._from_native_series(pc.if_else(cond, self.native, other.native)) def sample( self: Self, @@ -637,16 +623,14 @@ def sample( ) -> Self: import numpy as np # ignore-banned-import - ser = self._native_series num_rows = len(self) - if n is None and fraction is not None: n = int(num_rows * fraction) rng = np.random.default_rng(seed=seed) idx = np.arange(0, num_rows) mask = rng.choice(idx, size=n, replace=with_replacement) - return self._from_native_series(ser.take(mask)) # pyright: ignore[reportArgumentType] + return self._from_native_series(self.native.take(mask)) # pyright: ignore[reportArgumentType] def fill_null( self: Self, @@ -680,25 +664,22 @@ def fill_aux( arr, ) - ser = self._native_series - if value is not None: _, value = extract_native(self, value) - res_ser = self._from_native_series(pc.fill_null(ser, value)) # type: ignore[attr-defined] + series = pc.fill_null(self.native, value) # type: ignore[attr-defined] elif limit is None: fill_func = ( pc.fill_null_forward if strategy == "forward" else pc.fill_null_backward ) - res_ser = self._from_native_series(fill_func(ser)) + series = fill_func(self.native) else: - res_ser = self._from_native_series(fill_aux(ser, limit, strategy)) - - return res_ser + series = fill_aux(self.native, limit, strategy) + return self._from_native_series(series) def to_frame(self: Self) -> ArrowDataFrame: from narwhals._arrow.dataframe import ArrowDataFrame - df = pa.Table.from_arrays([self._native_series], names=[self.name]) + df = pa.Table.from_arrays([self.native], names=[self.name]) return ArrowDataFrame( df, backend_version=self._backend_version, @@ -709,12 +690,12 @@ def to_frame(self: Self) -> ArrowDataFrame: def to_pandas(self: Self) -> pd.Series[Any]: import pandas as pd # ignore-banned-import() - return pd.Series(self._native_series, name=self.name) # pyright: ignore[reportArgumentType, reportCallIssue] + return pd.Series(self.native, name=self.name) # pyright: ignore[reportArgumentType, reportCallIssue] def to_polars(self: Self) -> pl.Series: import polars as pl # ignore-banned-import - return pl.from_arrow(self._native_series) # type: ignore[return-value] + return pl.from_arrow(self.native) # type: ignore[return-value] def is_unique(self: Self) -> Self: return self.to_frame().is_unique().alias(self.name) # type: ignore[return-value] @@ -725,7 +706,7 @@ def is_first_distinct(self: Self) -> Self: row_number = pa.array(np.arange(len(self))) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) first_distinct_index = ( - pa.Table.from_arrays([self._native_series], names=[self.name]) + pa.Table.from_arrays([self.native], names=[self.name]) .append_column(col_token, row_number) .group_by(self.name) .aggregate([(col_token, "min")]) @@ -740,7 +721,7 @@ def is_last_distinct(self: Self) -> Self: row_number = pa.array(np.arange(len(self))) col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) last_distinct_index = ( - pa.Table.from_arrays([self._native_series], names=[self.name]) + pa.Table.from_arrays([self.native], names=[self.name]) .append_column(col_token, row_number) .group_by(self.name) .aggregate([(col_token, "max")]) @@ -753,23 +734,21 @@ def is_sorted(self: Self, *, descending: bool) -> bool: if not isinstance(descending, bool): msg = f"argument 'descending' should be boolean, found {type(descending)}" raise TypeError(msg) - - ser = self._native_series if descending: - result = pc.all(pc.greater_equal(ser[:-1], ser[1:])) + result = pc.all(pc.greater_equal(self.native[:-1], self.native[1:])) else: - result = pc.all(pc.less_equal(ser[:-1], ser[1:])) + result = pc.all(pc.less_equal(self.native[:-1], self.native[1:])) return maybe_extract_py_scalar(result, return_py_scalar=True) def unique(self: Self, *, maintain_order: bool) -> Self: # TODO(marco): `pc.unique` seems to always maintain order, is that guaranteed? - return self._from_native_series(self._native_series.unique()) + return self._from_native_series(self.native.unique()) def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None ) -> Self: # https://stackoverflow.com/a/79111029/4451315 - idxs = pc.index_in(self._native_series, pa.array(old)) + idxs = pc.index_in(self.native, pa.array(old)) result_native = pc.take(pa.array(new), idxs) if return_dtype is not None: result_native.cast(narwhals_to_native_dtype(return_dtype, self._version)) @@ -784,23 +763,21 @@ def replace_strict( return result def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: - series = self._native_series order: Order = "descending" if descending else "ascending" null_placement: NullPlacement = "at_end" if nulls_last else "at_start" sorted_indices = pc.array_sort_indices( - series, order=order, null_placement=null_placement + self.native, order=order, null_placement=null_placement ) - return self._from_native_series(series.take(sorted_indices)) + return self._from_native_series(self.native.take(sorted_indices)) def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFrame: import numpy as np # ignore-banned-import from narwhals._arrow.dataframe import ArrowDataFrame - series = self._native_series name = self._name # NOTE: stub is missing attributes (https://arrow.apache.org/docs/python/generated/pyarrow.DictionaryArray.html) - da: Incomplete = series.combine_chunks().dictionary_encode(null_encoding="encode") + da: Incomplete = self.native.combine_chunks().dictionary_encode("encode") columns: _2DArray = np.zeros((len(da.dictionary), len(da)), np.int8) columns[da.indices, np.arange(len(da))] = 1 @@ -835,29 +812,28 @@ def quantile( _return_py_scalar: bool = True, ) -> float: return maybe_extract_py_scalar( - pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0], + pc.quantile(self.native, q=quantile, interpolation=interpolation)[0], _return_py_scalar, ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: - return self._from_native_series(self._native_series[offset::n]) + return self._from_native_series(self.native[offset::n]) def clip( self: Self, lower_bound: Self | Any | None, upper_bound: Self | Any | None ) -> Self: - arr = self._native_series _, lower_bound = extract_native(self, lower_bound) _, upper_bound = extract_native(self, upper_bound) # NOTE: stubs are missing `ChunkedArray` support # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954 max_element_wise: Incomplete = pc.max_element_wise - arr = max_element_wise(arr, lower_bound) + arr = max_element_wise(self.native, lower_bound) arr = cast("ArrowChunkedArray", pc.min_element_wise(arr, upper_bound)) return self._from_native_series(arr) def to_arrow(self: Self) -> ArrowArray: - return self._native_series.combine_chunks() + return self.native.combine_chunks() def mode(self: Self) -> Self: plx = self.__narwhals_namespace__() @@ -870,7 +846,7 @@ def mode(self: Self) -> Self: ).filter(plx.col(col_token) == plx.col(col_token).max())[self.name] def is_finite(self: Self) -> Self: - return self._from_native_series(pc.is_finite(self._native_series)) + return self._from_native_series(pc.is_finite(self.native)) def cum_count(self: Self, *, reverse: bool) -> Self: dtypes = import_dtypes_module(self._version) @@ -881,7 +857,7 @@ def cum_min(self: Self, *, reverse: bool) -> Self: msg = "cum_min method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = cast("Any", self._native_series) + native_series = cast("Any", self.native) result = ( pc.cumulative_min(native_series, skip_nulls=True) @@ -895,7 +871,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = cast("Any", self._native_series) + native_series = cast("Any", self.native) result = ( pc.cumulative_max(native_series, skip_nulls=True) @@ -909,7 +885,7 @@ def cum_prod(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = cast("Any", self._native_series) + native_series = cast("Any", self.native) result = ( pc.cumulative_prod(native_series, skip_nulls=True) @@ -1077,9 +1053,9 @@ def rank( native_series: ArrowChunkedArray | ArrowArray if self._backend_version < (14, 0, 0): # pragma: no cover - native_series = self._native_series.combine_chunks() + native_series = self.native.combine_chunks() else: - native_series = self._native_series + native_series = self.native null_mask = pc.is_null(native_series) @@ -1103,7 +1079,7 @@ def hist( # noqa: PLR0915 from narwhals._arrow.dataframe import ArrowDataFrame def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa: ANN202 - d = pc.min_max(self._native_series) + d = pc.min_max(self.native) lower, upper = d["min"], d["max"] pa_float = pa.type_for_alias("float") if lower == upper: @@ -1117,9 +1093,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count))) bin_proportions = pc.divide( - pc.subtract( - cast("pc.NumericOrTemporalArray", self._native_series), lower - ), + pc.subtract(cast("pc.NumericOrTemporalArray", self.native), lower), width, ) bin_indices: ArrowChunkedArray = cast( @@ -1166,7 +1140,7 @@ def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa return counts.column("counts"), bin_right def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def] # noqa: ANN202 - bin_indices = np.searchsorted(bins, self._native_series, side="left") + bin_indices = np.searchsorted(bins, self.native, side="left") obs_cats, obs_counts = np.unique(bin_indices, return_counts=True) obj_cats = np.arange(1, len(bins)) counts = np.zeros_like(obj_cats) @@ -1205,10 +1179,8 @@ def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def ) def __iter__(self: Self) -> Iterator[Any]: - yield from ( - maybe_extract_py_scalar(x, return_py_scalar=True) - for x in self._native_series.__iter__() - ) + for x in self.native: + yield maybe_extract_py_scalar(x, return_py_scalar=True) def __contains__(self: Self, other: Any) -> bool: from pyarrow import ArrowInvalid # ignore-banned-imports @@ -1216,12 +1188,9 @@ def __contains__(self: Self, other: Any) -> bool: from pyarrow import ArrowTypeError # ignore-banned-imports try: - native_series = self._native_series - other_ = ( - lit(other) if other is not None else lit(None, type=native_series.type) - ) + other_ = lit(other) if other is not None else lit(None, type=self._type) return maybe_extract_py_scalar( - pc.is_in(other_, native_series), return_py_scalar=True + pc.is_in(other_, self.native), return_py_scalar=True ) except (ArrowInvalid, ArrowNotImplementedError, ArrowTypeError) as exc: from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_arrow/series_cat.py b/narwhals/_arrow/series_cat.py index 730903427d..af963e1716 100644 --- a/narwhals/_arrow/series_cat.py +++ b/narwhals/_arrow/series_cat.py @@ -4,6 +4,8 @@ import pyarrow as pa +from narwhals._arrow.utils import ArrowSeriesNamespace + if TYPE_CHECKING: from typing_extensions import Self @@ -11,13 +13,8 @@ from narwhals._arrow.typing import Incomplete -class ArrowSeriesCatNamespace: - def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series: ArrowSeries = series - +class ArrowSeriesCatNamespace(ArrowSeriesNamespace): def get_categories(self: Self) -> ArrowSeries: # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes - chunks: Incomplete = self._compliant_series._native_series.chunks - return self._compliant_series._from_native_series( - pa.concat_arrays(x.dictionary for x in chunks).unique() - ) + chunks: Incomplete = self.native.chunks + return self.from_native(pa.concat_arrays(x.dictionary for x in chunks).unique()) diff --git a/narwhals/_arrow/series_dt.py b/narwhals/_arrow/series_dt.py index 6b45a0faf5..23dcab43a9 100644 --- a/narwhals/_arrow/series_dt.py +++ b/narwhals/_arrow/series_dt.py @@ -1,11 +1,13 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any from typing import cast import pyarrow as pa import pyarrow.compute as pc +from narwhals._arrow.utils import ArrowSeriesNamespace from narwhals._arrow.utils import floordiv_compat from narwhals._arrow.utils import lit from narwhals.utils import import_dtypes_module @@ -15,45 +17,43 @@ from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import ArrowChunkedArray + from narwhals.dtypes import Datetime from narwhals.typing import TimeUnit -class ArrowSeriesDateTimeNamespace: - def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series: ArrowSeries = series +class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace): + @property + def unit(self) -> TimeUnit: # NOTE: Unsafe (native). + return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit + + @property + def time_zone(self) -> str | None: # NOTE: Unsafe (narwhals). + return cast("Datetime", self.compliant.dtype).time_zone def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 # PyArrow differs from other libraries in that %S also prints out # the fractional part of the second...:'( # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html format = format.replace("%S.%f", "%S").replace("%S%.f", "%S") - return self._compliant_series._from_native_series( - pc.strftime(self._compliant_series._native_series, format) - ) + return self.from_native(pc.strftime(self.native, format)) def replace_time_zone(self: Self, time_zone: str | None) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series if time_zone is not None: - result = pc.assume_timezone(pc.local_timestamp(ser._native_series), time_zone) + result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone) else: - result = pc.local_timestamp(ser._native_series) - return self._compliant_series._from_native_series(result) + result = pc.local_timestamp(self.native) + return self.from_native(result) def convert_time_zone(self: Self, time_zone: str) -> ArrowSeries: - if self._compliant_series.dtype.time_zone is None: # type: ignore[attr-defined] - ser: ArrowSeries = self.replace_time_zone("UTC") - else: - ser = self._compliant_series - native_type = pa.timestamp(ser._type.unit, time_zone) # type: ignore[attr-defined] - result = ser._native_series.cast(native_type) - return self._compliant_series._from_native_series(result) + ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant + return self.from_native(ser.native.cast(pa.timestamp(self.unit, time_zone))) def timestamp(self: Self, time_unit: TimeUnit) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series + ser: ArrowSeries = self.compliant dtypes = import_dtypes_module(ser._version) if isinstance(ser.dtype, dtypes.Datetime): unit = ser.dtype.time_unit - s_cast = ser._native_series.cast(pa.int64()) + s_cast = self.native.cast(pa.int64()) if unit == "ns": if time_unit == "ns": result = s_cast @@ -86,7 +86,7 @@ def timestamp(self: Self, time_unit: TimeUnit) -> ArrowSeries: msg = f"unexpected time unit {unit}, please report an issue at https://github.com/narwhals-dev/narwhals" raise AssertionError(msg) elif isinstance(ser.dtype, dtypes.Date): - time_s = pc.multiply(ser._native_series.cast(pa.int32()), 86400) + time_s = pc.multiply(self.native.cast(pa.int32()), 86400) if time_unit == "ns": result = cast("ArrowChunkedArray", pc.multiply(time_s, 1_000_000_000)) elif time_unit == "us": @@ -96,148 +96,99 @@ def timestamp(self: Self, time_unit: TimeUnit) -> ArrowSeries: else: msg = "Input should be either of Date or Datetime type" raise TypeError(msg) - return self._compliant_series._from_native_series(result) + return self.from_native(result) def date(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - self._compliant_series._native_series.cast(pa.date32()) - ) + return self.from_native(self.native.cast(pa.date32())) def year(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.year(self._compliant_series._native_series) - ) + return self.from_native(pc.year(self.native)) def month(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.month(self._compliant_series._native_series) - ) + return self.from_native(pc.month(self.native)) def day(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.day(self._compliant_series._native_series) - ) + return self.from_native(pc.day(self.native)) def hour(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.hour(self._compliant_series._native_series) - ) + return self.from_native(pc.hour(self.native)) def minute(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.minute(self._compliant_series._native_series) - ) + return self.from_native(pc.minute(self.native)) def second(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.second(self._compliant_series._native_series) - ) + return self.from_native(pc.second(self.native)) def millisecond(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.millisecond(self._compliant_series._native_series) - ) + return self.from_native(pc.millisecond(self.native)) def microsecond(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series - arr = ser._native_series + arr = self.native result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr)) - return self._compliant_series._from_native_series(result) + return self.from_native(result) def nanosecond(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series result = pc.add( - pc.multiply(self.microsecond()._native_series, lit(1000)), - pc.nanosecond(ser._native_series), + pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native) ) - return self._compliant_series._from_native_series(result) + return self.from_native(result) def ordinal_day(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.day_of_year(self._compliant_series._native_series) - ) + return self.from_native(pc.day_of_year(self.native)) def weekday(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.day_of_week(self._compliant_series._native_series, count_from_zero=False) - ) + return self.from_native(pc.day_of_week(self.native, count_from_zero=False)) def total_minutes(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series unit_to_minutes_factor = { "s": 60, # seconds "ms": 60 * 1e3, # milli "us": 60 * 1e6, # micro "ns": 60 * 1e9, # nano } - unit = ser._type.unit # type: ignore[attr-defined] - factor = lit(unit_to_minutes_factor[unit], type=pa.int64()) - return self._compliant_series._from_native_series( - pc.cast(pc.divide(ser._native_series, factor), pa.int64()) - ) + factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64()) + return self.from_native(pc.divide(self.native, factor).cast(pa.int64())) def total_seconds(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series unit_to_seconds_factor = { "s": 1, # seconds "ms": 1e3, # milli "us": 1e6, # micro "ns": 1e9, # nano } - unit = ser._type.unit # type: ignore[attr-defined] - factor = lit(unit_to_seconds_factor[unit], type=pa.int64()) - return self._compliant_series._from_native_series( - pc.cast(pc.divide(ser._native_series, factor), pa.int64()) - ) + factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64()) + return self.from_native(pc.divide(self.native, factor).cast(pa.int64())) def total_milliseconds(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series - arr = ser._native_series - unit = ser._type.unit # type: ignore[attr-defined] unit_to_milli_factor = { "s": 1e3, # seconds "ms": 1, # milli "us": 1e3, # micro "ns": 1e6, # nano } - factor = lit(unit_to_milli_factor[unit], type=pa.int64()) - if unit == "s": - return self._compliant_series._from_native_series( - pc.cast(pc.multiply(arr, factor), pa.int64()) - ) - return self._compliant_series._from_native_series( - pc.cast(pc.divide(arr, factor), pa.int64()) - ) + factor = lit(unit_to_milli_factor[self.unit], type=pa.int64()) + if self.unit == "s": + return self.from_native(pc.multiply(self.native, factor).cast(pa.int64())) + return self.from_native(pc.divide(self.native, factor).cast(pa.int64())) def total_microseconds(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series - arr = ser._native_series - unit = ser._type.unit # type: ignore[attr-defined] unit_to_micro_factor = { "s": 1e6, # seconds "ms": 1e3, # milli "us": 1, # micro "ns": 1e3, # nano } - factor = lit(unit_to_micro_factor[unit], type=pa.int64()) - if unit in {"s", "ms"}: - return self._compliant_series._from_native_series( - pc.cast(pc.multiply(arr, factor), pa.int64()) - ) - return self._compliant_series._from_native_series( - pc.cast(pc.divide(arr, factor), pa.int64()) - ) + factor = lit(unit_to_micro_factor[self.unit], type=pa.int64()) + if self.unit in {"s", "ms"}: + return self.from_native(pc.multiply(self.native, factor).cast(pa.int64())) + return self.from_native(pc.divide(self.native, factor).cast(pa.int64())) def total_nanoseconds(self: Self) -> ArrowSeries: - ser: ArrowSeries = self._compliant_series unit_to_nano_factor = { "s": 1e9, # seconds "ms": 1e6, # milli "us": 1e3, # micro "ns": 1, # nano } - unit = ser._type.unit # type: ignore[attr-defined] - factor = lit(unit_to_nano_factor[unit], type=pa.int64()) - return self._compliant_series._from_native_series( - pc.cast(pc.multiply(ser._native_series, factor), pa.int64()) - ) + factor = lit(unit_to_nano_factor[self.unit], type=pa.int64()) + return self.from_native(pc.multiply(self.native, factor).cast(pa.int64())) diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index 05fa3b3f5e..4c13bdda4e 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -5,17 +5,14 @@ import pyarrow as pa import pyarrow.compute as pc +from narwhals._arrow.utils import ArrowSeriesNamespace + if TYPE_CHECKING: from typing_extensions import Self from narwhals._arrow.series import ArrowSeries -class ArrowSeriesListNamespace: - def __init__(self: Self, series: ArrowSeries) -> None: - self._arrow_series: ArrowSeries = series - +class ArrowSeriesListNamespace(ArrowSeriesNamespace): def len(self: Self) -> ArrowSeries: - return self._arrow_series._from_native_series( - pc.cast(pc.list_value_length(self._arrow_series._native_series), pa.uint32()) - ) + return self.from_native(pc.list_value_length(self.native).cast(pa.uint32())) diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index d44e54b262..0026d89ae4 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -6,6 +6,7 @@ import pyarrow.compute as pc +from narwhals._arrow.utils import ArrowSeriesNamespace from narwhals._arrow.utils import lit from narwhals._arrow.utils import parse_datetime_format @@ -17,22 +18,16 @@ from narwhals._arrow.typing import Incomplete -class ArrowSeriesStringNamespace: - def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series: ArrowSeries = series - +class ArrowSeriesStringNamespace(ArrowSeriesNamespace): def len_chars(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.utf8_length(self._compliant_series._native_series) - ) + return self.from_native(pc.utf8_length(self.native)) def replace( self: Self, pattern: str, value: str, *, literal: bool, n: int ) -> ArrowSeries: - compliant = self._compliant_series fn = pc.replace_substring if literal else pc.replace_substring_regex - arr = fn(compliant._native_series, pattern, replacement=value, max_replacements=n) - return compliant._from_native_series(arr) + arr = fn(self.native, pattern, replacement=value, max_replacements=n) + return self.from_native(arr) def replace_all( self: Self, pattern: str, value: str, *, literal: bool @@ -40,57 +35,42 @@ def replace_all( return self.replace(pattern, value, literal=literal, n=-1) def strip_chars(self: Self, characters: str | None) -> ArrowSeries: - whitespace = string.whitespace - return self._compliant_series._from_native_series( - pc.utf8_trim( - self._compliant_series._native_series, - characters or whitespace, - ) + return self.from_native( + pc.utf8_trim(self.native, characters or string.whitespace) ) def starts_with(self: Self, prefix: str) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.equal(self.slice(0, len(prefix))._native_series, lit(prefix)) - ) + return self.from_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix))) def ends_with(self: Self, suffix: str) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.equal(self.slice(-len(suffix), None)._native_series, lit(suffix)) + return self.from_native( + pc.equal(self.slice(-len(suffix), None).native, lit(suffix)) ) def contains(self: Self, pattern: str, *, literal: bool) -> ArrowSeries: check_func = pc.match_substring if literal else pc.match_substring_regex - return self._compliant_series._from_native_series( - check_func(self._compliant_series._native_series, pattern) - ) + return self.from_native(check_func(self.native, pattern)) def slice(self: Self, offset: int, length: int | None) -> ArrowSeries: stop = offset + length if length is not None else None - return self._compliant_series._from_native_series( - pc.utf8_slice_codeunits( - self._compliant_series._native_series, start=offset, stop=stop - ) + return self.from_native( + pc.utf8_slice_codeunits(self.native, start=offset, stop=stop) ) def split(self: Self, by: str) -> ArrowSeries: - split_series = pc.split_pattern(self._compliant_series._native_series, by) # type: ignore[call-overload] - return self._compliant_series._from_native_series(split_series) + split_series = pc.split_pattern(self.native, by) # type: ignore[call-overload] + return self.from_native(split_series) def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002 - native = self._compliant_series._native_series - format = parse_datetime_format(native) if format is None else format + format = parse_datetime_format(self.native) if format is None else format strptime: Incomplete = pc.strptime timestamp_array: pa.Array[pa.TimestampScalar[Any, Any]] = strptime( - native, format=format, unit="us" + self.native, format=format, unit="us" ) - return self._compliant_series._from_native_series(timestamp_array) + return self.from_native(timestamp_array) def to_uppercase(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.utf8_upper(self._compliant_series._native_series), - ) + return self.from_native(pc.utf8_upper(self.native)) def to_lowercase(self: Self) -> ArrowSeries: - return self._compliant_series._from_native_series( - pc.utf8_lower(self._compliant_series._native_series), - ) + return self.from_native(pc.utf8_lower(self.native)) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 577d9b2168..bfc617cee9 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -12,15 +12,19 @@ import pyarrow as pa import pyarrow.compute as pc +from narwhals.utils import _ExprNamespace +from narwhals.utils import _SeriesNamespace from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from typing import TypeVar + from typing_extensions import Self from typing_extensions import TypeAlias from typing_extensions import TypeIs + from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import ArrowArray from narwhals._arrow.typing import ArrowChunkedArray @@ -225,22 +229,22 @@ def extract_native( from narwhals._arrow.series import ArrowSeries if rhs is None: - return lhs._native_series, lit(None, type=lhs._native_series.type) + return lhs.native, lit(None, type=lhs._type) if isinstance(rhs, ArrowDataFrame): return NotImplemented if isinstance(rhs, ArrowSeries): if lhs._broadcast and not rhs._broadcast: - return lhs._native_series[0], rhs._native_series + return lhs.native[0], rhs.native if rhs._broadcast: - return lhs._native_series, rhs._native_series[0] - return lhs._native_series, rhs._native_series + return lhs.native, rhs.native[0] + return lhs.native, rhs.native if isinstance(rhs, list): msg = "Expected Series or scalar, got list." raise TypeError(msg) - return lhs._native_series, rhs + return lhs.native, rhs def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: @@ -255,13 +259,12 @@ def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: is_max_length_gt_1 = max_length > 1 reshaped = [] for s, length in zip(series, lengths): - s_native = s._native_series if is_max_length_gt_1 and length == 1: - value = s_native[0] + value = s.native[0] if s._backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() reshaped.append( - s._from_native_series(pa.array([value] * max_length, type=s_native.type)) + s._from_native_series(pa.array([value] * max_length, type=s._type)) ) else: reshaped.append(s) @@ -275,17 +278,15 @@ def extract_dataframe_comparand( backend_version: tuple[int, ...], ) -> ArrowChunkedArray: """Extract native Series, broadcasting to `length` if necessary.""" - import numpy as np # ignore-banned-import - - if other._broadcast: - import numpy as np # ignore-banned-import + if not other._broadcast: + return other.native - value = other._native_series[0] - if backend_version < (13,) and hasattr(value, "as_py"): - value = value.as_py() - return pa.chunked_array([np.full(shape=length, fill_value=value)]) + import numpy as np # ignore-banned-import - return other._native_series + value = other.native[0] + if backend_version < (13,) and hasattr(value, "as_py"): + value = value.as_py() + return pa.chunked_array([np.full(shape=length, fill_value=value)]) def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: @@ -533,22 +534,22 @@ def pad_series( Returns: A tuple containing the padded ArrowSeries and the offset value. """ - if center: - offset_left = window_size // 2 - offset_right = offset_left - ( - window_size % 2 == 0 - ) # subtract one if window_size is even - - native_series = series._native_series - - pad_left = pa.array([None] * offset_left, type=native_series.type) - pad_right = pa.array([None] * offset_right, type=native_series.type) - padded_arr = series._from_native_series( - pa.concat_arrays([pad_left, *native_series.chunks, pad_right]) - ) - offset = offset_left + offset_right - else: - padded_arr = series - offset = 0 + if not center: + return series, 0 + offset_left = window_size // 2 + # subtract one if window_size is even + offset_right = offset_left - (window_size % 2 == 0) + pad_left = pa.array([None] * offset_left, type=series._type) + pad_right = pa.array([None] * offset_right, type=series._type) + concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right]) + return series._from_native_series(concat), offset_left + offset_right + + +class ArrowSeriesNamespace(_SeriesNamespace["ArrowSeries", "ArrowChunkedArray"]): + def __init__(self: Self, series: ArrowSeries, /) -> None: + self._compliant_series = series + - return padded_arr, offset +class ArrowExprNamespace(_ExprNamespace["ArrowExpr"]): + def __init__(self: Self, expr: ArrowExpr, /) -> None: + self._compliant_expr = expr diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 800081f176..b80f2859c2 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -177,7 +177,7 @@ def _from_iterable( ) def __len__(self: Self) -> int: - return len(self._native_series) + return len(self.native) @property def name(self: Self) -> str: @@ -185,15 +185,19 @@ def name(self: Self) -> str: @property def dtype(self: Self) -> DType: - native_dtype = self._native_series.dtype + native_dtype = self.native.dtype return ( native_to_narwhals_dtype(native_dtype, self._version, self._implementation) if native_dtype != "object" else object_native_to_narwhals_dtype( - self._native_series, self._version, self._implementation + self.native, self._version, self._implementation ) ) + @property + def native(self) -> Any: + return self._native_series + def ewm_mean( self: Self, *, diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 73133a3bea..23e9ab6674 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -115,6 +115,10 @@ def dtype(self: Self) -> DType: self._native_series.dtype, self._version, self._backend_version ) + @property + def native(self) -> pl.Series: + return self._native_series + def alias(self, name: str) -> Self: return self._from_native_object(self._native_series.alias(name)) @@ -130,9 +134,8 @@ def __getitem__( return self._from_native_object(self._native_series.__getitem__(item)) def cast(self: Self, dtype: DType) -> Self: - ser = self._native_series dtype_pl = narwhals_to_native_dtype(dtype, self._version, self._backend_version) - return self._from_native_series(ser.cast(dtype_pl)) + return self._from_native_series(self.native.cast(dtype_pl)) def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None diff --git a/narwhals/typing.py b/narwhals/typing.py index 9b36394571..7ec2179a3a 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -67,8 +67,11 @@ class CompliantSeries(Protocol): def dtype(self) -> DType: ... @property def name(self) -> str: ... + @property + def native(self) -> Any: ... def __narwhals_series__(self) -> CompliantSeries: ... def alias(self, name: str) -> Self: ... + def _from_native_series(self, series: Any) -> Self: ... CompliantSeriesT_co = TypeVar( diff --git a/narwhals/utils.py b/narwhals/utils.py index 43844171b8..1d384cc1da 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -13,6 +13,7 @@ from typing import Container from typing import Iterable from typing import Literal +from typing import Protocol from typing import Sequence from typing import TypeVar from typing import Union @@ -45,7 +46,6 @@ if TYPE_CHECKING: from types import ModuleType from typing import AbstractSet as Set - from typing import Protocol import pandas as pd from typing_extensions import Self @@ -61,7 +61,6 @@ from narwhals.typing import CompliantFrameT from narwhals.typing import CompliantLazyFrame from narwhals.typing import CompliantSeries - from narwhals.typing import CompliantSeriesT_co from narwhals.typing import DataFrameLike from narwhals.typing import DTypes from narwhals.typing import IntoSeriesT @@ -118,6 +117,75 @@ class _StoresColumns(Protocol): def columns(self) -> Sequence[str]: ... +NativeT_co = TypeVar("NativeT_co", covariant=True) +CompliantT_co = TypeVar("CompliantT_co", covariant=True) +CompliantExprT_co = TypeVar( + "CompliantExprT_co", bound="CompliantExpr[Any, Any]", covariant=True +) +CompliantSeriesT_co = TypeVar( + "CompliantSeriesT_co", bound="CompliantSeries", covariant=True +) + + +class _StoresNative(Protocol[NativeT_co]): + """Provides access to a native object. + + Native objects have types like: + + >>> from pandas import Series + >>> from pyarrow import Table + """ + + @property + def native(self) -> NativeT_co: + """Return the native object.""" + ... + + +class _StoresCompliant(Protocol[CompliantT_co]): + """Provides access to a compliant object. + + Compliant objects have types like: + + >>> from narwhals._pandas_like.series import PandasLikeSeries + >>> from narwhals._arrow.dataframe import ArrowDataFrame + """ + + @property + def compliant(self) -> CompliantT_co: + """Return the compliant object.""" + ... + + +class _SeriesNamespace( # type: ignore[misc] # noqa: PYI046 + _StoresCompliant[CompliantSeriesT_co], + _StoresNative[NativeT_co], + Protocol[CompliantSeriesT_co, NativeT_co], +): + _compliant_series: CompliantSeriesT_co + + @property + def compliant(self) -> CompliantSeriesT_co: + return self._compliant_series + + @property + def native(self) -> NativeT_co: + return self._compliant_series.native + + def from_native(self, series: Any, /) -> CompliantSeriesT_co: + return self.compliant._from_native_series(series) + + +class _ExprNamespace( # type: ignore[misc] # noqa: PYI046 + _StoresCompliant[CompliantExprT_co], Protocol[CompliantExprT_co] +): + _compliant_expr: CompliantExprT_co + + @property + def compliant(self) -> CompliantExprT_co: + return self._compliant_expr + + class Version(Enum): V1 = auto() MAIN = auto()