diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 0e8a21afc4..c5ddcd063b 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import partial from typing import TYPE_CHECKING from typing import Any from typing import Iterator @@ -36,10 +37,6 @@ import pandas as pd import polars as pl - from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource] - Indices, - ) - from pyarrow._stubs_typing import Order # pyright: ignore[reportMissingModuleSource] from typing_extensions import Self from typing_extensions import TypeAlias @@ -47,6 +44,10 @@ from narwhals._arrow.group_by import ArrowGroupBy from narwhals._arrow.namespace import ArrowNamespace from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import ArrowChunkedArray + from narwhals._arrow.typing import Indices + from narwhals._arrow.typing import Mask + from narwhals._arrow.typing import Order from narwhals.dtypes import DType from narwhals.typing import SizeUnit from narwhals.typing import _1DArray @@ -133,7 +134,7 @@ def __len__(self: Self) -> int: return len(self._native_frame) def row(self: Self, index: int) -> tuple[Any, ...]: - return tuple(col[index] for col in self._native_frame) + return tuple(col[index] for col in self._native_frame.itercolumns()) @overload def rows(self: Self, *, named: Literal[True]) -> list[dict[str, Any]]: ... 
@@ -371,7 +372,9 @@ def with_columns(self: Self, *exprs: ArrowExpr) -> Self: native_frame = ( native_frame.set_column( - columns.index(col_name), field_=col_name, column=column + columns.index(col_name), + field_=col_name, + column=column, # type: ignore[arg-type] ) if col_name in columns else native_frame.append_column(field_=col_name, column=column) @@ -532,9 +535,9 @@ def with_row_index(self: Self, name: str) -> Self: df.append_column(name, row_indices).select([name, *cols]) ) - def filter(self: Self, predicate: ArrowExpr | list[bool]) -> Self: + def filter(self: Self, predicate: ArrowExpr | list[bool | None]) -> Self: if isinstance(predicate, list): - mask_native = predicate + mask_native: Mask | ArrowChunkedArray = predicate else: # `[0]` is safe as the predicate's expression only returns a single column mask = evaluate_into_exprs(self, predicate)[0] @@ -542,7 +545,8 @@ def filter(self: Self, predicate: ArrowExpr | list[bool]) -> Self: length=len(self), other=mask, backend_version=self._backend_version ) return self._from_native_frame( - self._native_frame.filter(mask_native), validate_column_names=False + self._native_frame.filter(mask_native), # pyright: ignore[reportArgumentType] + validate_column_names=False, ) def head(self: Self, n: int) -> Self: @@ -745,17 +749,14 @@ def unique( agg_func = agg_func_map[keep] col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns) - keep_idx = ( + keep_idx_native = ( df.append_column(col_token, pa.array(np.arange(len(self)))) .group_by(subset) .aggregate([(col_token, agg_func)]) .column(f"{col_token}_{agg_func}") ) - - return self._from_native_frame( - pc.take(df, keep_idx), # type: ignore[call-overload, unused-ignore] - validate_column_names=False, - ) + indices = cast("Indices", keep_idx_native) + return self._from_native_frame(df.take(indices), validate_column_names=False) keep_idx = self.simple_select(*subset).is_unique() plx = self.__narwhals_namespace__() @@ -804,21 +805,20 @@ def unpivot( on_: 
list[str] = ( [c for c in self.columns if c not in index_] if on is None else on ) - - promote_kwargs: dict[Literal["promote_options"], PromoteOptions] = ( - {"promote_options": "permissive"} + concat = ( + partial(pa.concat_tables, promote_options="permissive") if self._backend_version >= (14, 0, 0) - else {} + else pa.concat_tables ) names = [*index_, variable_name, value_name] return self._from_native_frame( - pa.concat_tables( + concat( [ pa.Table.from_arrays( [ *(native_frame.column(idx_col) for idx_col in index_), cast( - "pa.ChunkedArray", + "ArrowChunkedArray", pa.array([on_col] * n_rows, pa.string()), ), native_frame.column(on_col), @@ -826,8 +826,7 @@ def unpivot( names=names, ) for on_col in on_ - ], - **promote_kwargs, + ] ) ) # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 89c48972e8..525cf0098c 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -48,7 +48,7 @@ def __init__( self._depth = depth self._function_name = function_name self._depth = depth - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 8284134038..32db3983f9 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Iterator +from typing import cast import pyarrow as pa import pyarrow.compute as pc @@ -18,6 +19,7 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr + from narwhals._arrow.typing import Incomplete POLARS_TO_ARROW_AGGREGATIONS = { "sum": "sum", @@ -68,7 +70,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: ) raise 
ValueError(msg) - aggs: list[tuple[str, str, pc.FunctionOptions | None]] = [] + aggs: list[tuple[str, str, Any]] = [] expected_pyarrow_column_names: list[str] = self._keys.copy() new_column_names: list[str] = self._keys.copy() @@ -91,7 +93,7 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: function_name = re.sub(r"(\w+->)", "", expr._function_name) if function_name in {"std", "var"}: - option = pc.VarianceOptions(ddof=expr._kwargs["ddof"]) + option: Any = pc.VarianceOptions(ddof=expr._kwargs["ddof"]) elif function_name in {"len", "n_unique"}: option = pc.CountOptions(mode="all") elif function_name == "count": @@ -139,14 +141,19 @@ def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) - null_token = "__null_token_value__" # noqa: S105 + null_token: str = "__null_token_value__" # noqa: S105 table = self._df._native_frame - key_values = pc.binary_join_element_wise( - *[pc.cast(table[key], pa.string()) for key in self._keys], - "", - null_handling="replace", - null_replacement=null_token, + # NOTE: stubs fail in multiple places for `ChunkedArray` + it = cast( + "Iterator[pa.StringArray]", + (table[key].cast(pa.string()) for key in self._keys), + ) + # NOTE: stubs indicate `separator` must also be a `ChunkedArray` + # Reality: `str` is fine + concat_str: Incomplete = pc.binary_join_element_wise + key_values = concat_str( + *it, "", null_handling="replace", null_replacement=null_token ) table = table.add_column(i=0, field_=col_token, column=key_values) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 115f573d0c..5ce31083bc 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -19,6 +19,7 @@ from narwhals._arrow.utils import broadcast_series from narwhals._arrow.utils import diagonal_concat from narwhals._arrow.utils import horizontal_concat +from 
narwhals._arrow.utils import nulls_like from narwhals._arrow.utils import vertical_concat from narwhals._expression_parsing import combine_alias_output_names from narwhals._expression_parsing import combine_evaluate_output_names @@ -31,6 +32,7 @@ from typing_extensions import Self + from narwhals._arrow.typing import Incomplete from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType from narwhals.utils import Version @@ -254,13 +256,16 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def min_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: def func(df: ArrowDataFrame) -> list[ArrowSeries]: init_series, *series = [s for _expr in exprs for s in _expr(df)] + # NOTE: Stubs copy the wrong signature https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L963 + min_element_wise: Incomplete = pc.min_element_wise + native_series = reduce( + min_element_wise, + [s._native_series for s in series], + init_series._native_series, + ) return [ ArrowSeries( - native_series=reduce( - pc.min_element_wise, - [s._native_series for s in series], - init_series._native_series, - ), + native_series, name=init_series.name, backend_version=self._backend_version, version=self._version, @@ -279,13 +284,17 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: def max_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr: def func(df: ArrowDataFrame) -> list[ArrowSeries]: init_series, *series = [s for _expr in exprs for s in _expr(df)] + # NOTE: stubs are missing `ChunkedArray` support + # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954 + max_element_wise: Incomplete = pc.max_element_wise + native_series = reduce( + max_element_wise, + [s._native_series for s in series], + init_series._native_series, + ) return [ ArrowSeries( - native_series=reduce( - pc.max_element_wise, - [s._native_series for s in series], - 
init_series._native_series, - ), + native_series, name=init_series.name, backend_version=self._backend_version, version=self._version, @@ -347,18 +356,19 @@ def concat_str( dtypes = import_dtypes_module(self._version) def func(df: ArrowDataFrame) -> list[ArrowSeries]: - compliant_series_list = [ + compliant_series_list: list[ArrowSeries] = [ s for _expr in exprs for s in _expr.cast(dtypes.String())(df) ] - null_handling = "skip" if ignore_nulls else "emit_null" - result_series = pc.binary_join_element_wise( - *(s._native_series for s in compliant_series_list), - separator, - null_handling=null_handling, + null_handling: Literal["skip", "emit_null"] = ( + "skip" if ignore_nulls else "emit_null" ) + it = (s._native_series for s in compliant_series_list) + # NOTE: stubs indicate `separator` must also be a `ChunkedArray` + # Reality: `str` is fine + concat_str: Incomplete = pc.binary_join_element_wise return [ ArrowSeries( - native_series=result_series, + native_series=concat_str(*it, separator, null_handling=null_handling), name=compliant_series_list[0].name, backend_version=self._backend_version, version=self._version, @@ -410,14 +420,11 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: condition_native, value_series_native = broadcast_series( [condition, value_series] ) - if self._otherwise_value is None: - otherwise_native = pa.repeat( - pa.scalar(None, type=value_series_native.type), len(condition_native) - ) + otherwise_null = nulls_like(len(condition_native), value_series) return [ value_series._from_native_series( - pc.if_else(condition_native, value_series_native, otherwise_native) + pc.if_else(condition_native, value_series_native, otherwise_null) ) ] if isinstance(self._otherwise_value, ArrowExpr): @@ -474,7 +481,7 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._evaluate_output_names = evaluate_output_names + self._evaluate_output_names = evaluate_output_names # pyright: 
ignore[reportAttributeAccessIssue] self._alias_output_names = alias_output_names self._kwargs = kwargs diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 27cac5ddd8..f2790a117b 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -6,6 +6,7 @@ from typing import Iterator from typing import Literal from typing import Sequence +from typing import cast from typing import overload import pyarrow as pa @@ -17,9 +18,12 @@ from narwhals._arrow.series_str import ArrowSeriesStringNamespace from narwhals._arrow.utils import broadcast_and_extract_native from narwhals._arrow.utils import cast_for_truediv +from narwhals._arrow.utils import chunked_array from narwhals._arrow.utils import floordiv_compat +from narwhals._arrow.utils import lit from narwhals._arrow.utils import narwhals_to_native_dtype from narwhals._arrow.utils import native_to_narwhals_dtype +from narwhals._arrow.utils import nulls_like from narwhals._arrow.utils import pad_series from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantSeries @@ -37,13 +41,52 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace + from narwhals._arrow.typing import ArrowArray + from narwhals._arrow.typing import ArrowChunkedArray + from narwhals._arrow.typing import Incomplete + from narwhals._arrow.typing import Indices + from narwhals._arrow.typing import NullPlacement + from narwhals._arrow.typing import Order + from narwhals._arrow.typing import TieBreaker + from narwhals._arrow.typing import _AsPyType + from narwhals._arrow.typing import _BasicDataType from narwhals.dtypes import DType from narwhals.typing import _1DArray from narwhals.typing import _2DArray from narwhals.utils import Version +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[_BasicDataType[_AsPyType]], + return_py_scalar: bool, # noqa: FBT001 +) -> _AsPyType: ... 
+ + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[pa.StructType], + return_py_scalar: bool, # noqa: FBT001 +) -> list[dict[str, Any]]: ... + + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[pa.ListType[_BasicDataType[_AsPyType]]], + return_py_scalar: bool, # noqa: FBT001 +) -> list[_AsPyType]: ... + + +@overload +def maybe_extract_py_scalar( + value: pa.Scalar[Any] | Any, + return_py_scalar: bool, # noqa: FBT001 +) -> Any: ... + + def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001 + if TYPE_CHECKING: + return value.as_py() if return_py_scalar: return getattr(value, "as_py", lambda: value)() return value @@ -52,14 +95,14 @@ def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: class ArrowSeries(CompliantSeries): def __init__( self: Self, - native_series: pa.ChunkedArray, + native_series: ArrowChunkedArray, *, name: str, backend_version: tuple[int, ...], version: Version, ) -> None: self._name = name - self._native_series = native_series + self._native_series: ArrowChunkedArray = native_series self._implementation = Implementation.PYARROW self._backend_version = backend_version self._version = version @@ -73,11 +116,12 @@ def _change_version(self: Self, version: Version) -> Self: version=version, ) - def _from_native_series(self: Self, series: pa.ChunkedArray | pa.Array) -> Self: - if isinstance(series, pa.Array): - series = pa.chunked_array([series]) + def _from_native_series( + self: Self, + series: ArrowArray | ArrowChunkedArray, + ) -> Self: return self.__class__( - series, + chunked_array(series), name=self._name, backend_version=self._backend_version, version=self._version, @@ -93,7 +137,7 @@ def _from_iterable( version: Version, ) -> Self: return cls( - pa.chunked_array([data]), + chunked_array([data]), name=name, backend_version=backend_version, version=version, @@ -110,12 +154,12 @@ def __len__(self: Self) -> int: return len(self._native_series) def __eq__(self: Self, 
other: object) -> Self: # type: ignore[override] - ser, other = broadcast_and_extract_native(self, other, self._backend_version) - return self._from_native_series(pc.equal(ser, other)) + ser, right = broadcast_and_extract_native(self, other, self._backend_version) + return self._from_native_series(pc.equal(ser, right)) def __ne__(self: Self, other: object) -> Self: # type: ignore[override] - ser, other = broadcast_and_extract_native(self, other, self._backend_version) - return self._from_native_series(pc.not_equal(ser, other)) + ser, right = broadcast_and_extract_native(self, other, self._backend_version) + return self._from_native_series(pc.not_equal(ser, right)) def __ge__(self: Self, other: Any) -> Self: ser, other = broadcast_and_extract_native(self, other, self._backend_version) @@ -190,15 +234,15 @@ def __truediv__(self: Self, other: Any) -> Self: ser, other = broadcast_and_extract_native(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar - other = pa.scalar(other) + other = lit(other) return self._from_native_series(pc.divide(*cast_for_truediv(ser, other))) def __rtruediv__(self: Self, other: Any) -> Self: - ser, other = broadcast_and_extract_native(self, other, self._backend_version) - if not isinstance(other, (pa.Array, pa.ChunkedArray)): + ser, right = broadcast_and_extract_native(self, other, self._backend_version) + if not isinstance(right, (pa.Array, pa.ChunkedArray)): # scalar - other = pa.scalar(other) - return self._from_native_series(pc.divide(*cast_for_truediv(other, ser))) + right = lit(right) if not isinstance(right, pa.Scalar) else right + return self._from_native_series(pc.divide(*cast_for_truediv(right, ser))) def __mod__(self: Self, other: Any) -> Self: floor_div = (self // other)._native_series @@ -213,10 +257,16 @@ def __rmod__(self: Self, other: Any) -> Self: return self._from_native_series(res) def __invert__(self: Self) -> Self: - return 
self._from_native_series(pc.invert(self._native_series)) + return self._from_native_series( + pc.invert(self._native_series) # type: ignore[call-overload] + ) + + @property + def _type(self: Self) -> pa.DataType: + return self._native_series.type def len(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) # type: ignore[no-any-return] + return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) def filter(self: Self, other: Any) -> Self: if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): @@ -225,17 +275,20 @@ def filter(self: Self, other: Any) -> Self: ser = self._native_series return self._from_native_series(ser.filter(other)) - def mean(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar) # type: ignore[no-any-return] + def mean(self: Self, *, _return_py_scalar: bool = True) -> float: + # NOTE: stub overly strict https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L274-L307 + # docs say numeric https://arrow.apache.org/docs/python/generated/pyarrow.compute.mean.html + mean: Incomplete = pc.mean + return maybe_extract_py_scalar(mean(self._native_series), _return_py_scalar) - def median(self: Self, *, _return_py_scalar: bool = True) -> int: + def median(self: Self, *, _return_py_scalar: bool = True) -> float: from narwhals.exceptions import InvalidOperationError if not self.dtype.is_numeric(): msg = "`median` operation not supported for non-numeric input type." 
raise InvalidOperationError(msg) - return maybe_extract_py_scalar( # type: ignore[no-any-return] + return maybe_extract_py_scalar( pc.approximate_median(self._native_series), _return_py_scalar ) @@ -247,44 +300,44 @@ def max(self: Self, *, _return_py_scalar: bool = True) -> Any: def arg_min(self: Self, *, _return_py_scalar: bool = True) -> int: index_min = pc.index(self._native_series, pc.min(self._native_series)) - return maybe_extract_py_scalar(index_min, _return_py_scalar) # type: ignore[no-any-return] + return maybe_extract_py_scalar(index_min, _return_py_scalar) def arg_max(self: Self, *, _return_py_scalar: bool = True) -> int: index_max = pc.index(self._native_series, pc.max(self._native_series)) - return maybe_extract_py_scalar(index_max, _return_py_scalar) # type: ignore[no-any-return] + return maybe_extract_py_scalar(index_max, _return_py_scalar) - def sum(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar( # type: ignore[no-any-return] + def sum(self: Self, *, _return_py_scalar: bool = True) -> float: + return maybe_extract_py_scalar( pc.sum(self._native_series, min_count=0), _return_py_scalar ) - def drop_nulls(self: Self) -> ArrowSeries: - return self._from_native_series(pc.drop_null(self._native_series)) + def drop_nulls(self: Self) -> Self: + return self._from_native_series(self._native_series.drop_null()) def shift(self: Self, n: int) -> Self: ca = self._native_series - if n > 0: - result = pa.concat_arrays([pa.nulls(n, ca.type), *ca[:-n].chunks]) + arrays = [nulls_like(n, self), *ca[:-n].chunks] elif n < 0: - result = pa.concat_arrays([*ca[-n:].chunks, pa.nulls(-n, ca.type)]) + arrays = [*ca[-n:].chunks, nulls_like(-n, self)] else: - result = ca - return self._from_native_series(result) + return self._from_native_series(ca) + return self._from_native_series(pa.concat_arrays(arrays)) def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: - return maybe_extract_py_scalar( # type: 
ignore[no-any-return] + return maybe_extract_py_scalar( pc.stddev(self._native_series, ddof=ddof), _return_py_scalar ) def var(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: - return maybe_extract_py_scalar( # type: ignore[no-any-return] + return maybe_extract_py_scalar( pc.variance(self._native_series, ddof=ddof), _return_py_scalar ) def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: ser = self._native_series - ser_not_null = pc.drop_null(ser) + # NOTE: stub issue with `pc.subtract`, `pc.mean` and `pa.ChunkedArray` + ser_not_null: Incomplete = ser.drop_null() if len(ser_not_null) == 0: return None elif len(ser_not_null) == 1: @@ -292,20 +345,20 @@ def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: elif len(ser_not_null) == 2: return 0.0 else: - m = pc.subtract(ser_not_null, pc.mean(ser_not_null)) - m2 = pc.mean(pc.power(m, 2)) - m3 = pc.mean(pc.power(m, 3)) - # Biased population skewness - return maybe_extract_py_scalar( # type: ignore[no-any-return] - pc.divide(m3, pc.power(m2, 1.5)), _return_py_scalar + m = cast( + "pc.NumericArray[Any]", pc.subtract(ser_not_null, pc.mean(ser_not_null)) ) + m2 = pc.mean(pc.power(m, lit(2))) + m3 = pc.mean(pc.power(m, lit(3))) + biased_population_skewness = pc.divide(m3, pc.power(m2, lit(1.5))) + return maybe_extract_py_scalar(biased_population_skewness, _return_py_scalar) def count(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int: - unique_values = pc.unique(self._native_series) - return maybe_extract_py_scalar( # type: ignore[no-any-return] + unique_values = self._native_series.unique() + return maybe_extract_py_scalar( pc.count(unique_values, mode="all"), _return_py_scalar ) @@ -337,7 +390,9 @@ def 
__getitem__( self._native_series[idx], return_py_scalar=True ) if isinstance(idx, (Sequence, pa.ChunkedArray)): - return self._from_native_series(self._native_series.take(idx)) + return self._from_native_series( + self._native_series.take(cast("Indices", idx)) + ) return self._from_native_series(self._native_series[idx]) def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: @@ -355,7 +410,9 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: values = values.combine_chunks() if not isinstance(values, pa.Array): values = pa.array(values) - result = pc.replace_with_mask(ser, mask, values.take(indices)) + result = pc.replace_with_mask( + ser, cast("list[bool]", mask), values.take(cast("Indices", indices)) + ) return self._from_native_series(result) def to_list(self: Self) -> list[Any]: @@ -384,31 +441,43 @@ def abs(self: Self) -> Self: def cum_sum(self: Self, *, reverse: bool) -> Self: native_series = self._native_series + # NOTE: stub only permits `NumericArray` + # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L140 + cum_sum: Incomplete = pc.cumulative_sum result = ( - pc.cumulative_sum(native_series, skip_nulls=True) + cum_sum(native_series, skip_nulls=True) if not reverse - else pc.cumulative_sum(native_series[::-1], skip_nulls=True)[::-1] + else cum_sum(native_series[::-1], skip_nulls=True)[::-1] ) return self._from_native_series(result) def round(self: Self, decimals: int) -> Self: + # NOTE: stub only permits `NumericArray` + # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L140 + pc_round: Incomplete = pc.round return self._from_native_series( - pc.round(self._native_series, decimals, round_mode="half_towards_infinity") + pc_round(self._native_series, decimals, round_mode="half_towards_infinity") ) def diff(self: Self) -> Self: - return self._from_native_series( - 
pc.pairwise_diff(self._native_series.combine_chunks()) - ) + # NOTE: stub only permits `ChunkedArray[TemporalScalar]` + # (https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L145-L148) + diff: Incomplete = pc.pairwise_diff + return self._from_native_series(diff(self._native_series.combine_chunks())) def any(self: Self, *, _return_py_scalar: bool = True) -> bool: - return maybe_extract_py_scalar( # type: ignore[no-any-return] - pc.any(self._native_series, min_count=0), _return_py_scalar + # NOTE: stub restricts to `BooleanArray`, should be based on truthiness + # Copies `pc.all` + pc_any: Incomplete = pc.any + return maybe_extract_py_scalar( + pc_any(self._native_series, min_count=0), _return_py_scalar ) def all(self: Self, *, _return_py_scalar: bool = True) -> bool: - return maybe_extract_py_scalar( # type: ignore[no-any-return] - pc.all(self._native_series, min_count=0), _return_py_scalar + # NOTE: stub restricts to `BooleanArray`, should be based on truthiness + pc_all: Incomplete = pc.all + return maybe_extract_py_scalar( + pc_all(self._native_series, min_count=0), _return_py_scalar ) def is_between( @@ -453,11 +522,11 @@ def is_nan(self: Self) -> Self: def cast(self: Self, dtype: DType) -> Self: ser = self._native_series - dtype = narwhals_to_native_dtype(dtype, self._version) - return self._from_native_series(pc.cast(ser, dtype)) + data_type = narwhals_to_native_dtype(dtype, self._version) + return self._from_native_series(pc.cast(ser, data_type)) def null_count(self: Self, *, _return_py_scalar: bool = True) -> int: - return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) # type: ignore[no-any-return] + return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) def head(self: Self, n: int) -> Self: ser = self._native_series @@ -522,16 +591,16 @@ def value_counts( index_name_ = "index" if self._name is None else self._name value_name_ = name or 
("proportion" if normalize else "count") - val_count = pc.value_counts(self._native_series) - values = val_count.field("values") - counts = val_count.field("counts") + val_counts = pc.value_counts(self._native_series) + values = val_counts.field("values") + counts = cast("ArrowChunkedArray", val_counts.field("counts")) if normalize: - counts = pc.divide(*cast_for_truediv(counts, pc.sum(counts))) + arrays = [values, pc.divide(*cast_for_truediv(counts, pc.sum(counts)))] + else: + arrays = [values, counts] - val_count = pa.Table.from_arrays( - [values, counts], names=[index_name_, value_name_] - ) + val_count = pa.Table.from_arrays(arrays, names=[index_name_, value_name_]) if sort: val_count = val_count.sort_by([(value_name_, "descending")]) @@ -544,13 +613,9 @@ def value_counts( ) def zip_with(self: Self, mask: Self, other: Self) -> Self: - mask = mask._native_series.combine_chunks() + cond = mask._native_series.combine_chunks() return self._from_native_series( - pc.if_else( - mask, - self._native_series, - other._native_series, - ) + pc.if_else(cond, self._native_series, other._native_series) ) def sample( @@ -572,8 +637,7 @@ def sample( rng = np.random.default_rng(seed=seed) idx = np.arange(0, num_rows) mask = rng.choice(idx, size=n, replace=with_replacement) - - return self._from_native_series(pc.take(ser, mask)) + return self._from_native_series(ser.take(mask)) def fill_null( self: Self, @@ -584,10 +648,10 @@ def fill_null( import numpy as np # ignore-banned-import def fill_aux( - arr: pa.Array, + arr: ArrowArray | ArrowChunkedArray, limit: int, direction: Literal["forward", "backward"] | None = None, - ) -> pa.Array: + ) -> ArrowArray: # this algorithm first finds the indices of the valid values to fill all the null value positions # then it calculates the distance of each new index and the original index # if the distance is equal to or less than the limit and the original value is null, it is replaced @@ -602,10 +666,7 @@ def fill_aux( )[::-1] distance = 
valid_index - indices return pc.if_else( - pc.and_( - pc.is_null(arr), - pc.less_equal(distance, pa.scalar(limit)), - ), + pc.and_(pc.is_null(arr), pc.less_equal(distance, lit(limit))), arr.take(valid_index), arr, ) @@ -614,7 +675,7 @@ def fill_aux( dtype = ser.type if value is not None: - res_ser = self._from_native_series(pc.fill_null(ser, pa.scalar(value, dtype))) + res_ser = self._from_native_series(pc.fill_null(ser, lit(value, dtype))) # type: ignore[attr-defined] elif limit is None: fill_func = ( pc.fill_null_forward if strategy == "forward" else pc.fill_null_backward @@ -639,15 +700,15 @@ def to_frame(self: Self) -> ArrowDataFrame: def to_pandas(self: Self) -> pd.Series: import pandas as pd # ignore-banned-import() - return pd.Series(self._native_series, name=self.name) + return pd.Series(self._native_series, name=self.name) # pyright: ignore[reportArgumentType, reportCallIssue] def to_polars(self: Self) -> pl.Series: import polars as pl # ignore-banned-import return pl.from_arrow(self._native_series) # type: ignore[return-value] - def is_unique(self: Self) -> ArrowSeries: - return self.to_frame().is_unique().alias(self.name) + def is_unique(self: Self) -> Self: + return self.to_frame().is_unique().alias(self.name) # type: ignore[return-value] def is_first_distinct(self: Self) -> Self: import numpy as np # ignore-banned-import @@ -689,15 +750,15 @@ def is_sorted(self: Self, *, descending: bool) -> bool: result = pc.all(pc.greater_equal(ser[:-1], ser[1:])) else: result = pc.all(pc.less_equal(ser[:-1], ser[1:])) - return maybe_extract_py_scalar(result, return_py_scalar=True) # type: ignore[no-any-return] + return maybe_extract_py_scalar(result, return_py_scalar=True) - def unique(self: Self, *, maintain_order: bool) -> ArrowSeries: + def unique(self: Self, *, maintain_order: bool) -> Self: # TODO(marco): `pc.unique` seems to always maintain order, is that guaranteed? 
- return self._from_native_series(pc.unique(self._native_series)) + return self._from_native_series(self._native_series.unique()) def replace_strict( self: Self, old: Sequence[Any], new: Sequence[Any], *, return_dtype: DType | None - ) -> ArrowSeries: + ) -> Self: # https://stackoverflow.com/a/79111029/4451315 idxs = pc.index_in(self._native_series, pa.array(old)) result_native = pc.take(pa.array(new), idxs) @@ -713,15 +774,14 @@ def replace_strict( raise ValueError(msg) return result - def sort(self: Self, *, descending: bool, nulls_last: bool) -> ArrowSeries: + def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: series = self._native_series - order = "descending" if descending else "ascending" - null_placement = "at_end" if nulls_last else "at_start" + order: Order = "descending" if descending else "ascending" + null_placement: NullPlacement = "at_end" if nulls_last else "at_start" sorted_indices = pc.array_sort_indices( series, order=order, null_placement=null_placement ) - - return self._from_native_series(pc.take(series, sorted_indices)) + return self._from_native_series(series.take(sorted_indices)) def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFrame: import numpy as np # ignore-banned-import @@ -730,7 +790,8 @@ def to_dummies(self: Self, *, separator: str, drop_first: bool) -> ArrowDataFram series = self._native_series name = self._name - da = series.dictionary_encode(null_encoding="encode").combine_chunks() + # NOTE: stub is missing attributes (https://arrow.apache.org/docs/python/generated/pyarrow.DictionaryArray.html) + da: Incomplete = series.combine_chunks().dictionary_encode(null_encoding="encode") columns: _2DArray = np.zeros((len(da.dictionary), len(da)), np.int8) columns[da.indices, np.arange(len(da))] = 1 @@ -764,7 +825,7 @@ def quantile( *, _return_py_scalar: bool = True, ) -> float: - return maybe_extract_py_scalar( # type: ignore[no-any-return] + return maybe_extract_py_scalar( 
pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0], _return_py_scalar, ) @@ -782,18 +843,21 @@ def clip( _, upper_bound = broadcast_and_extract_native( self, upper_bound, self._backend_version ) - arr = pc.max_element_wise(arr, lower_bound) - arr = pc.min_element_wise(arr, upper_bound) + # NOTE: stubs are missing `ChunkedArray` support + # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954 + max_element_wise: Incomplete = pc.max_element_wise + arr = max_element_wise(arr, lower_bound) + arr = cast("ArrowChunkedArray", pc.min_element_wise(arr, upper_bound)) return self._from_native_series(arr) - def to_arrow(self: Self) -> pa.Array: + def to_arrow(self: Self) -> ArrowArray: return self._native_series.combine_chunks() - def mode(self: Self) -> ArrowSeries: + def mode(self: Self) -> Self: plx = self.__narwhals_namespace__() col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name]) - return self.value_counts( + return self.value_counts( # type: ignore[return-value] name=col_token, normalize=False, sort=False, @@ -812,7 +876,7 @@ def cum_min(self: Self, *, reverse: bool) -> Self: msg = "cum_min method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = self._native_series + native_series = cast("Any", self._native_series) result = ( pc.cumulative_min(native_series, skip_nulls=True) @@ -826,7 +890,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = self._native_series + native_series = cast("Any", self._native_series) result = ( pc.cumulative_max(native_series, skip_nulls=True) @@ -840,7 +904,7 @@ def cum_prod(self: Self, *, reverse: bool) -> Self: msg = "cum_max method is not supported for pyarrow < 13.0.0" raise NotImplementedError(msg) - native_series = self._native_series + native_series = cast("Any", 
self._native_series) result = ( pc.cumulative_prod(native_series, skip_nulls=True) @@ -957,6 +1021,9 @@ def rolling_var( count_in_window = valid_count - valid_count.shift(window_size).fill_null( value=0, strategy=None, limit=None ) + # NOTE: stubs are missing `ChunkedArray` support + # https://github.com/zen-xu/pyarrow-stubs/blob/d97063876720e6a5edda7eb15f4efe07c31b8296/pyarrow-stubs/compute.pyi#L948-L954 + max_element_wise: Incomplete = pc.max_element_wise result = self._from_native_series( pc.if_else( @@ -965,7 +1032,7 @@ def rolling_var( None, ) ) / self._from_native_series( - pc.max_element_wise((count_in_window - ddof)._native_series, 0) + max_element_wise((count_in_window - ddof)._native_series, 0) ) return result[offset:] @@ -1000,18 +1067,20 @@ def rank( # ignore-banned-import - sort_keys = "descending" if descending else "ascending" - tiebreaker = "first" if method == "ordinal" else method + sort_keys: Order = "descending" if descending else "ascending" + tiebreaker: TieBreaker = "first" if method == "ordinal" else method - native_series = self._native_series + native_series: ArrowChunkedArray | ArrowArray if self._backend_version < (14, 0, 0): # pragma: no cover - native_series = native_series.combine_chunks() + native_series = self._native_series.combine_chunks() + else: + native_series = self._native_series null_mask = pc.is_null(native_series) rank = pc.rank(native_series, sort_keys=sort_keys, tiebreaker=tiebreaker) - result = pc.if_else(null_mask, pa.scalar(None), rank) + result = pc.if_else(null_mask, lit(None, native_series.type), rank) return self._from_native_series(result) def hist( # noqa: PLR0915 @@ -1028,26 +1097,37 @@ def hist( # noqa: PLR0915 from narwhals._arrow.dataframe import ArrowDataFrame - def _hist_from_bin_count( - bin_count: int, - ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]: + def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa: ANN202 d = pc.min_max(self._native_series) 
lower, upper = d["min"], d["max"] pad_lowest_bin = False + pa_float = pa.type_for_alias("float") if lower == upper: - range_ = pa.scalar(1.0) - width = pc.divide(range_, bin_count) - lower = pc.subtract(lower, 0.5) - upper = pc.add(upper, 0.5) + range_ = lit(1.0) + mid = lit(0.5) + width = pc.divide(range_, lit(bin_count)) + lower = pc.subtract(lower, mid) + upper = pc.add(upper, mid) else: pad_lowest_bin = True range_ = pc.subtract(upper, lower) - width = pc.divide(range_.cast("float"), float(bin_count)) + width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count))) - bin_proportions = pc.divide(pc.subtract(self._native_series, lower), width) - bin_indices = pc.floor(bin_proportions) + bin_proportions = pc.divide( + pc.subtract( + cast("pc.NumericOrTemporalArray", self._native_series), lower + ), + width, + ) + bin_indices: ArrowChunkedArray = cast( + "ArrowChunkedArray", pc.floor(bin_proportions) + ) - bin_indices = pc.if_else( # shift bins so they are right-closed + # NOTE: stubs leave unannotated + if_else: Incomplete = pc.if_else + + # shift bins so they are right-closed + bin_indices = if_else( pc.and_( pc.equal(bin_indices, bin_proportions), pc.greater(bin_indices, 0), @@ -1055,6 +1135,9 @@ def _hist_from_bin_count( pc.subtract(bin_indices, 1), bin_indices, ) + possible = pa.Table.from_arrays( + [pa.Array.from_pandas(np.arange(bin_count, dtype="int64"))], ["values"] + ) counts = ( # count bin id occurrences pa.Table.from_arrays( pc.value_counts(bin_indices).flatten(), @@ -1063,38 +1146,31 @@ def _hist_from_bin_count( # nan values are implicitly dropped in value_counts .filter(~pc.field("values").is_nan()) .cast(pa.schema([("values", pa.int64()), ("counts", pa.int64())])) - .join( # align bin ids to all possible bin ids (populate in missing bins) - pa.Table.from_arrays( - [np.arange(bin_count, dtype="int64")], ["values"] - ), - keys="values", - join_type="right outer", - ) + # align bin ids to all possible bin ids (populate in missing bins) + 
.join(possible, keys="values", join_type="right outer") .sort_by("values") ) - counts = counts.set_column( # empty bin intervals should have a 0 count - 0, "counts", pc.coalesce(counts.column("counts"), 0) + # empty bin intervals should have a 0 count + counts_coalesce = cast( + "ArrowArray", + pc.coalesce(cast("ArrowArray", counts.column("counts")), lit(0)), ) + counts = counts.set_column(0, "counts", counts_coalesce) # extract left/right side of the intervals bin_left = pc.add(lower, pc.multiply(counts.column("values"), width)) bin_right = pc.add(bin_left, width) if pad_lowest_bin: - bin_left = pa.chunked_array( - [ # pad lowest bin by 1% of range - [ - pc.subtract( - bin_left[0], pc.multiply(range_.cast("float"), 0.001) - ) - ], - bin_left[1:], # pyarrow==11.0 needs to infer - ] - ) + # pad lowest bin by 1% of range + lowest_padded = [ + pc.subtract( + bin_left[0], pc.multiply(pc.cast(range_, pa_float), lit(0.001)) + ) + ] + bin_left = chunked_array([lowest_padded, cast("Any", bin_left[1:])]) return counts.column("counts"), bin_left, bin_right - def _hist_from_bins( - bins: Sequence[int | float], - ) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]: + def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def] # noqa: ANN202 bin_indices = np.searchsorted(bins, self._native_series, side="left") obs_cats, obs_counts = np.unique(bin_indices, return_counts=True) obj_cats = np.arange(1, len(bins)) @@ -1105,9 +1181,6 @@ def _hist_from_bins( bin_left = bins[:-1] return counts, bin_left, bin_right - counts: Sequence[int] - bin_left: Sequence[int | float] - bin_right: Sequence[int | float] if bins is not None: if len(bins) < 2: counts, bin_left, bin_right = [], [], [] @@ -1125,7 +1198,7 @@ def _hist_from_bins( msg = "must provide one of `bin_count` or `bins`" raise InvalidOperationError(msg) - data: dict[str, Sequence[int | float | str]] = {} + data: dict[str, Any] = {} if include_breakpoint: data["breakpoint"] = bin_right 
data["count"] = counts @@ -1151,13 +1224,10 @@ def __contains__(self: Self, other: Any) -> bool: try: native_series = self._native_series other_ = ( - pa.scalar(other) - if other is not None - else pa.scalar(None, type=native_series.type) + lit(other) if other is not None else lit(None, type=native_series.type) ) - return maybe_extract_py_scalar( # type: ignore[no-any-return] - pc.is_in(other_, native_series), - return_py_scalar=True, + return maybe_extract_py_scalar( + pc.is_in(other_, native_series), return_py_scalar=True ) except (ArrowInvalid, ArrowNotImplementedError, ArrowTypeError) as exc: from narwhals.exceptions import InvalidOperationError diff --git a/narwhals/_arrow/series_cat.py b/narwhals/_arrow/series_cat.py index 4f8def0f73..730903427d 100644 --- a/narwhals/_arrow/series_cat.py +++ b/narwhals/_arrow/series_cat.py @@ -8,15 +8,16 @@ from typing_extensions import Self from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import Incomplete class ArrowSeriesCatNamespace: def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series = series + self._compliant_series: ArrowSeries = series def get_categories(self: Self) -> ArrowSeries: - ca = self._compliant_series._native_series - out = pa.chunked_array( - [pa.concat_arrays(x.dictionary for x in ca.chunks).unique()] + # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes + chunks: Incomplete = self._compliant_series._native_series.chunks + return self._compliant_series._from_native_series( + pa.concat_arrays(x.dictionary for x in chunks).unique() ) - return self._compliant_series._from_native_series(out) diff --git a/narwhals/_arrow/series_dt.py b/narwhals/_arrow/series_dt.py index 7f10324def..3ba677523c 100644 --- a/narwhals/_arrow/series_dt.py +++ b/narwhals/_arrow/series_dt.py @@ -1,23 +1,27 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import cast import pyarrow as pa import pyarrow.compute as 
pc from narwhals._arrow.utils import floordiv_compat +from narwhals._arrow.utils import lit from narwhals.utils import import_dtypes_module +from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from typing_extensions import Self from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import ArrowChunkedArray from narwhals.typing import TimeUnit class ArrowSeriesDateTimeNamespace: def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series = series + self._compliant_series: ArrowSeries = series def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 # PyArrow differs from other libraries in that %S also prints out @@ -29,33 +33,28 @@ def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 ) def replace_time_zone(self: Self, time_zone: str | None) -> ArrowSeries: + ser: ArrowSeries = self._compliant_series if time_zone is not None: - result = pc.assume_timezone( - pc.local_timestamp(self._compliant_series._native_series), time_zone - ) + result = pc.assume_timezone(pc.local_timestamp(ser._native_series), time_zone) else: - result = pc.local_timestamp(self._compliant_series._native_series) + result = pc.local_timestamp(ser._native_series) return self._compliant_series._from_native_series(result) def convert_time_zone(self: Self, time_zone: str) -> ArrowSeries: if self._compliant_series.dtype.time_zone is None: # type: ignore[attr-defined] - result = self.replace_time_zone("UTC")._native_series.cast( - pa.timestamp(self._compliant_series._native_series.type.unit, time_zone) - ) + ser: ArrowSeries = self.replace_time_zone("UTC") else: - result = self._compliant_series._native_series.cast( - pa.timestamp(self._compliant_series._native_series.type.unit, time_zone) - ) - + ser = self._compliant_series + native_type = pa.timestamp(ser._type.unit, time_zone) # type: ignore[attr-defined] + result = ser._native_series.cast(native_type) return self._compliant_series._from_native_series(result) def 
timestamp(self: Self, time_unit: TimeUnit) -> ArrowSeries: - s = self._compliant_series._native_series - dtype = self._compliant_series.dtype - dtypes = import_dtypes_module(self._compliant_series._version) - if dtype == dtypes.Datetime: - unit = dtype.time_unit # type: ignore[attr-defined] - s_cast = s.cast(pa.int64()) + ser: ArrowSeries = self._compliant_series + dtypes = import_dtypes_module(ser._version) + if isinstance_or_issubclass(ser.dtype, dtypes.Datetime): + unit = ser.dtype.time_unit + s_cast = ser._native_series.cast(pa.int64()) if unit == "ns": if time_unit == "ns": result = s_cast @@ -65,36 +64,36 @@ def timestamp(self: Self, time_unit: TimeUnit) -> ArrowSeries: result = floordiv_compat(s_cast, 1_000_000) elif unit == "us": if time_unit == "ns": - result = pc.multiply(s_cast, 1_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000)) elif time_unit == "us": result = s_cast else: result = floordiv_compat(s_cast, 1_000) elif unit == "ms": if time_unit == "ns": - result = pc.multiply(s_cast, 1_000_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000_000)) elif time_unit == "us": - result = pc.multiply(s_cast, 1_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000)) else: result = s_cast elif unit == "s": if time_unit == "ns": - result = pc.multiply(s_cast, 1_000_000_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000_000_000)) elif time_unit == "us": - result = pc.multiply(s_cast, 1_000_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000_000)) else: - result = pc.multiply(s_cast, 1_000) + result = cast("ArrowChunkedArray", pc.multiply(s_cast, 1_000)) else: # pragma: no cover msg = f"unexpected time unit {unit}, please report an issue at https://github.com/narwhals-dev/narwhals" raise AssertionError(msg) - elif dtype == dtypes.Date: - time_s = pc.multiply(s.cast(pa.int32()), 86400) + elif isinstance_or_issubclass(ser.dtype, dtypes.Date): + time_s = 
pc.multiply(ser._native_series.cast(pa.int32()), 86400) if time_unit == "ns": - result = pc.multiply(time_s, 1_000_000_000) + result = cast("ArrowChunkedArray", pc.multiply(time_s, 1_000_000_000)) elif time_unit == "us": - result = pc.multiply(time_s, 1_000_000) + result = cast("ArrowChunkedArray", pc.multiply(time_s, 1_000_000)) else: - result = pc.multiply(time_s, 1_000) + result = cast("ArrowChunkedArray", pc.multiply(time_s, 1_000)) else: msg = "Input should be either of Date or Datetime type" raise TypeError(msg) @@ -141,15 +140,16 @@ def millisecond(self: Self) -> ArrowSeries: ) def microsecond(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - result = pc.add(pc.multiply(pc.millisecond(arr), 1000), pc.microsecond(arr)) - + ser: ArrowSeries = self._compliant_series + arr = ser._native_series + result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr)) return self._compliant_series._from_native_series(result) def nanosecond(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series + ser: ArrowSeries = self._compliant_series result = pc.add( - pc.multiply(self.microsecond()._native_series, 1000), pc.nanosecond(arr) + pc.multiply(self.microsecond()._native_series, lit(1000)), + pc.nanosecond(ser._native_series), ) return self._compliant_series._from_native_series(result) @@ -164,72 +164,63 @@ def weekday(self: Self) -> ArrowSeries: ) def total_minutes(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - unit = arr.type.unit - + ser: ArrowSeries = self._compliant_series unit_to_minutes_factor = { "s": 60, # seconds "ms": 60 * 1e3, # milli "us": 60 * 1e6, # micro "ns": 60 * 1e9, # nano } - - factor = pa.scalar(unit_to_minutes_factor[unit], type=pa.int64()) + unit = ser._type.unit # type: ignore[attr-defined] + factor = lit(unit_to_minutes_factor[unit], type=pa.int64()) return self._compliant_series._from_native_series( - pc.cast(pc.divide(arr, factor), pa.int64()) + 
pc.cast(pc.divide(ser._native_series, factor), pa.int64()) ) def total_seconds(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - unit = arr.type.unit - + ser: ArrowSeries = self._compliant_series unit_to_seconds_factor = { "s": 1, # seconds "ms": 1e3, # milli "us": 1e6, # micro "ns": 1e9, # nano } - factor = pa.scalar(unit_to_seconds_factor[unit], type=pa.int64()) - + unit = ser._type.unit # type: ignore[attr-defined] + factor = lit(unit_to_seconds_factor[unit], type=pa.int64()) return self._compliant_series._from_native_series( - pc.cast(pc.divide(arr, factor), pa.int64()) + pc.cast(pc.divide(ser._native_series, factor), pa.int64()) ) def total_milliseconds(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - unit = arr.type.unit - + ser: ArrowSeries = self._compliant_series + arr = ser._native_series + unit = ser._type.unit # type: ignore[attr-defined] unit_to_milli_factor = { "s": 1e3, # seconds "ms": 1, # milli "us": 1e3, # micro "ns": 1e6, # nano } - - factor = pa.scalar(unit_to_milli_factor[unit], type=pa.int64()) - + factor = lit(unit_to_milli_factor[unit], type=pa.int64()) if unit == "s": return self._compliant_series._from_native_series( pc.cast(pc.multiply(arr, factor), pa.int64()) ) - return self._compliant_series._from_native_series( pc.cast(pc.divide(arr, factor), pa.int64()) ) def total_microseconds(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - unit = arr.type.unit - + ser: ArrowSeries = self._compliant_series + arr = ser._native_series + unit = ser._type.unit # type: ignore[attr-defined] unit_to_micro_factor = { "s": 1e6, # seconds "ms": 1e3, # milli "us": 1, # micro "ns": 1e3, # nano } - - factor = pa.scalar(unit_to_micro_factor[unit], type=pa.int64()) - + factor = lit(unit_to_micro_factor[unit], type=pa.int64()) if unit in {"s", "ms"}: return self._compliant_series._from_native_series( pc.cast(pc.multiply(arr, factor), pa.int64()) @@ -239,18 +230,15 @@ def 
total_microseconds(self: Self) -> ArrowSeries: ) def total_nanoseconds(self: Self) -> ArrowSeries: - arr = self._compliant_series._native_series - unit = arr.type.unit - + ser: ArrowSeries = self._compliant_series unit_to_nano_factor = { "s": 1e9, # seconds "ms": 1e6, # milli "us": 1e3, # micro "ns": 1, # nano } - - factor = pa.scalar(unit_to_nano_factor[unit], type=pa.int64()) - + unit = ser._type.unit # type: ignore[attr-defined] + factor = lit(unit_to_nano_factor[unit], type=pa.int64()) return self._compliant_series._from_native_series( - pc.cast(pc.multiply(arr, factor), pa.int64()) + pc.cast(pc.multiply(ser._native_series, factor), pa.int64()) ) diff --git a/narwhals/_arrow/series_list.py b/narwhals/_arrow/series_list.py index c65e73b508..05fa3b3f5e 100644 --- a/narwhals/_arrow/series_list.py +++ b/narwhals/_arrow/series_list.py @@ -13,7 +13,7 @@ class ArrowSeriesListNamespace: def __init__(self: Self, series: ArrowSeries) -> None: - self._arrow_series = series + self._arrow_series: ArrowSeries = series def len(self: Self) -> ArrowSeries: return self._arrow_series._from_native_series( diff --git a/narwhals/_arrow/series_str.py b/narwhals/_arrow/series_str.py index 4d67f36d32..4f2b47ea6f 100644 --- a/narwhals/_arrow/series_str.py +++ b/narwhals/_arrow/series_str.py @@ -5,17 +5,20 @@ import pyarrow.compute as pc +from narwhals._arrow.utils import lit from narwhals._arrow.utils import parse_datetime_format if TYPE_CHECKING: + import pyarrow as pa from typing_extensions import Self from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import Incomplete class ArrowSeriesStringNamespace: def __init__(self: Self, series: ArrowSeries) -> None: - self._compliant_series = series + self._compliant_series: ArrowSeries = series def len_chars(self: Self) -> ArrowSeries: return self._compliant_series._from_native_series( @@ -25,15 +28,10 @@ def len_chars(self: Self) -> ArrowSeries: def replace( self: Self, pattern: str, value: str, *, literal: bool, n: 
int ) -> ArrowSeries: - method = "replace_substring" if literal else "replace_substring_regex" - return self._compliant_series._from_native_series( - getattr(pc, method)( - self._compliant_series._native_series, - pattern=pattern, - replacement=value, - max_replacements=n, - ) - ) + compliant = self._compliant_series + fn = pc.replace_substring if literal else pc.replace_substring_regex + arr = fn(compliant._native_series, pattern, replacement=value, max_replacements=n) + return compliant._from_native_series(arr) def replace_all( self: Self, pattern: str, value: str, *, literal: bool @@ -51,12 +49,12 @@ def strip_chars(self: Self, characters: str | None) -> ArrowSeries: def starts_with(self: Self, prefix: str) -> ArrowSeries: return self._compliant_series._from_native_series( - pc.equal(self.slice(0, len(prefix))._native_series, prefix) + pc.equal(self.slice(0, len(prefix))._native_series, lit(prefix)) ) def ends_with(self: Self, suffix: str) -> ArrowSeries: return self._compliant_series._from_native_series( - pc.equal(self.slice(-len(suffix), None)._native_series, suffix) + pc.equal(self.slice(-len(suffix), None)._native_series, lit(suffix)) ) def contains(self: Self, pattern: str, *, literal: bool) -> ArrowSeries: @@ -70,16 +68,17 @@ def slice(self: Self, offset: int, length: int | None) -> ArrowSeries: return self._compliant_series._from_native_series( pc.utf8_slice_codeunits( self._compliant_series._native_series, start=offset, stop=stop - ), + ) ) def to_datetime(self: Self, format: str | None) -> ArrowSeries: # noqa: A002 - if format is None: - format = parse_datetime_format(self._compliant_series._native_series) - - return self._compliant_series._from_native_series( - pc.strptime(self._compliant_series._native_series, format=format, unit="us") + native = self._compliant_series._native_series + format = parse_datetime_format(native) if format is None else format + strptime: Incomplete = pc.strptime + timestamp_array: pa.Array[pa.TimestampScalar] = strptime( + 
native, format=format, unit="us" ) + return self._compliant_series._from_native_series(timestamp_array) def to_uppercase(self: Self) -> ArrowSeries: return self._compliant_series._from_native_series( diff --git a/narwhals/_arrow/typing.py b/narwhals/_arrow/typing.py index 38b8c9be47..69e2596dd8 100644 --- a/narwhals/_arrow/typing.py +++ b/narwhals/_arrow/typing.py @@ -1,17 +1,55 @@ from __future__ import annotations # pragma: no cover from typing import TYPE_CHECKING # pragma: no cover -from typing import Union # pragma: no cover +from typing import Any # pragma: no cover +from typing import TypeVar # pragma: no cover if TYPE_CHECKING: import sys + from typing import Generic + from typing import Literal if sys.version_info >= (3, 10): from typing import TypeAlias else: from typing_extensions import TypeAlias + import pyarrow as pa + import pyarrow.compute as pc + from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource] + Indices, # noqa: F401 + ) + from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource] + Mask, # noqa: F401 + ) + from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource] + Order, # noqa: F401 + ) + from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.series import ArrowSeries - IntoArrowExpr: TypeAlias = Union[ArrowExpr, ArrowSeries] + IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries" + TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"] + NullPlacement: TypeAlias = Literal["at_start", "at_end"] + + StringArray: TypeAlias = "pc.StringArray" + ArrowChunkedArray: TypeAlias = pa.ChunkedArray[Any] + ArrowArray: TypeAlias = pa.Array[Any] + _AsPyType = TypeVar("_AsPyType") + + class _BasicDataType(pa.DataType, Generic[_AsPyType]): ... + + +Incomplete: TypeAlias = Any # pragma: no cover +""" +Marker for working code that fails on the stubs. 
+ +Common issues: +- Annotated for `Array`, but not `ChunkedArray` +- Relies on typing information that the stubs don't provide statically +- Missing attributes +- Incorrect return types +- Inconsistent use of generic/concrete types +- `_clone_signature` used on signatures that are not identical +""" diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 06bb2e5e83..fad917a206 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -1,8 +1,10 @@ from __future__ import annotations from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING from typing import Any +from typing import Iterable from typing import Sequence from typing import cast from typing import overload @@ -11,18 +13,80 @@ import pyarrow.compute as pc from narwhals.utils import import_dtypes_module +from narwhals.utils import is_compliant_expr +from narwhals.utils import is_compliant_series from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from typing import TypeVar + from typing_extensions import TypeAlias + from typing_extensions import TypeIs + from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import ArrowArray + from narwhals._arrow.typing import ArrowChunkedArray + from narwhals._arrow.typing import Incomplete + from narwhals._arrow.typing import StringArray from narwhals.dtypes import DType + from narwhals.typing import TimeUnit from narwhals.typing import _AnyDArray from narwhals.utils import Version + # NOTE: stubs don't allow for `ChunkedArray[StructArray]` + # Intended to represent the `.chunks` property storing `list[pa.StructArray]` + ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray + _T = TypeVar("_T") + def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ... + def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ... + def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ... + def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ... 
+    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
+    def is_dictionary(
+        t: Any,
+    ) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
+    def extract_regex(
+        strings: ArrowChunkedArray,
+        /,
+        pattern: str,
+        *,
+        options: Any = None,
+        memory_pool: Any = None,
+    ) -> ChunkedArrayStructArray: ...
+else:
+    from pyarrow.compute import extract_regex
+    from pyarrow.types import is_dictionary  # noqa: F401
+    from pyarrow.types import is_duration
+    from pyarrow.types import is_fixed_size_list
+    from pyarrow.types import is_large_list
+    from pyarrow.types import is_list
+    from pyarrow.types import is_timestamp
+
+lit = pa.scalar
+"""Alias for `pyarrow.scalar`."""
+
+
+def chunked_array(
+    arr: ArrowArray | list[Iterable[pa.Scalar[Any]]] | ArrowChunkedArray,
+) -> ArrowChunkedArray:
+    if isinstance(arr, pa.ChunkedArray):
+        return arr
+    if isinstance(arr, list):
+        return pa.chunked_array(cast("Any", arr))
+    else:
+        return pa.chunked_array([arr], arr.type)
+
+
+def nulls_like(n: int, series: ArrowSeries) -> ArrowArray:
+    """Create a strongly-typed Array instance with all elements null.
+
+    Uses the type of `series`, without upsetting `mypy`.
+ """ + nulls: Incomplete = pa.nulls + return nulls(n, series._type) + @lru_cache(maxsize=16) def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: @@ -59,9 +123,9 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: return dtypes.String() if pa.types.is_date32(dtype): return dtypes.Date() - if pa.types.is_timestamp(dtype): + if is_timestamp(dtype): return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz) - if pa.types.is_duration(dtype): + if is_duration(dtype): return dtypes.Duration(time_unit=dtype.unit) if pa.types.is_dictionary(dtype): return dtypes.Categorical() @@ -75,9 +139,9 @@ def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: for i in range(dtype.num_fields) ] ) - if pa.types.is_list(dtype) or pa.types.is_large_list(dtype): + if is_list(dtype) or is_large_list(dtype): return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version)) - if pa.types.is_fixed_size_list(dtype): + if is_fixed_size_list(dtype): return dtypes.Array( native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size ) @@ -118,7 +182,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa if isinstance_or_issubclass(dtype, dtypes.Categorical): return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): - time_unit = getattr(dtype, "time_unit", "us") + time_unit: TimeUnit = getattr(dtype, "time_unit", "us") time_zone = getattr(dtype, "time_zone", None) return pa.timestamp(time_unit, tz=time_zone) if isinstance_or_issubclass(dtype, dtypes.Duration): @@ -127,40 +191,40 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa if isinstance_or_issubclass(dtype, dtypes.Date): return pa.date32() if isinstance_or_issubclass(dtype, dtypes.List): - return pa.list_( - value_type=narwhals_to_native_dtype( - dtype.inner, # type: ignore[union-attr] - version=version, - ) - ) + return 
pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version)) if isinstance_or_issubclass(dtype, dtypes.Struct): return pa.struct( [ - ( - field.name, - narwhals_to_native_dtype( - field.dtype, - version=version, - ), - ) - for field in dtype.fields # type: ignore[union-attr] + (field.name, narwhals_to_native_dtype(field.dtype, version=version)) + for field in dtype.fields ] ) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover - inner = narwhals_to_native_dtype( - dtype.inner, # type: ignore[union-attr] - version=version, - ) - list_size = dtype.size # type: ignore[union-attr] + inner = narwhals_to_native_dtype(dtype.inner, version=version) + list_size = dtype.size return pa.list_(inner, list_size=list_size) msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) +@overload +def broadcast_and_extract_native( + lhs: ArrowSeries, rhs: None, backend_version: tuple[int, ...] +) -> tuple[ArrowChunkedArray, pa.Scalar[Any]]: ... + + +@overload +def broadcast_and_extract_native( + lhs: ArrowSeries, + rhs: ArrowSeries | list[ArrowSeries] | Any, + backend_version: tuple[int, ...], +) -> tuple[ArrowChunkedArray, pa.Scalar[Any] | ArrowChunkedArray]: ... + + def broadcast_and_extract_native( lhs: ArrowSeries, rhs: Any, backend_version: tuple[int, ...] -) -> tuple[pa.ChunkedArray, Any]: +) -> tuple[ArrowChunkedArray, Any]: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -172,17 +236,15 @@ def broadcast_and_extract_native( from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries - if rhs is None: - return lhs._native_series, pa.scalar(None, type=lhs._native_series.type) + if rhs is None: # DONE + return lhs._native_series, lit(None, type=lhs._native_series.type) # If `rhs` is the output of an expression evaluation, then it is # a list of Series. 
So, we verify that that list is of length-1, # and take the first (and only) element. if isinstance(rhs, list): if len(rhs) > 1: - if hasattr(rhs[0], "__narwhals_expr__") or hasattr( - rhs[0], "__narwhals_series__" - ): + if is_compliant_expr(rhs[0]) or is_compliant_series(rhs[0]): # e.g. `plx.all() + plx.all()` msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context" raise ValueError(msg) @@ -204,24 +266,18 @@ def broadcast_and_extract_native( fill_value = lhs[0] if backend_version < (13,) and hasattr(fill_value, "as_py"): fill_value = fill_value.as_py() - left_result = pa.chunked_array( - [ - pa.array( - np.full(shape=rhs.len(), fill_value=fill_value), - type=lhs._native_series.type, - ) - ] + arr = pa.array( + np.full(shape=rhs.len(), fill_value=fill_value), + type=lhs._native_series.type, ) - return left_result, rhs._native_series + return chunked_array(arr), rhs._native_series return lhs._native_series, rhs._native_series return lhs._native_series, rhs def broadcast_and_extract_dataframe_comparand( - length: int, - other: Any, - backend_version: tuple[int, ...], -) -> Any: + length: int, other: ArrowSeries, backend_version: tuple[int, ...] +) -> ArrowChunkedArray: """Validate RHS of binary operation. 
If the comparison isn't supported, return `NotImplemented` so that the @@ -237,7 +293,7 @@ def broadcast_and_extract_dataframe_comparand( value = other._native_series[0] if backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() - return pa.array(np.full(shape=length, fill_value=value)) + return pa.chunked_array([np.full(shape=length, fill_value=value)]) return other._native_series @@ -260,8 +316,7 @@ def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: if len(set(names)) < len(names): # pragma: no cover msg = "Expected unique column names" raise ValueError(msg) - - arrays = [a for df in dfs for a in df] + arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) return pa.Table.from_arrays(arrays, names=names) @@ -289,10 +344,10 @@ def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa Should be in namespace. """ - kwargs = ( + kwargs: dict[str, Any] = ( {"promote": True} if backend_version < (14, 0, 0) - else {"promote_options": "default"} # type: ignore[dict-item] + else {"promote_options": "default"} ) return pa.concat_tables(dfs, **kwargs) @@ -301,10 +356,10 @@ def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. 
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 if isinstance(left, (int, float)): - left = pa.scalar(left) + left = lit(left) if isinstance(right, (int, float)): - right = pa.scalar(right) + right = lit(right) if pa.types.is_integer(left.type) and pa.types.is_integer(right.type): divided = pc.divide_checked(left, right) @@ -312,16 +367,12 @@ def floordiv_compat(left: Any, right: Any) -> Any: # GH 56676 has_remainder = pc.not_equal(pc.multiply(divided, right), left) has_one_negative_operand = pc.less( - pc.bit_wise_xor(left, right), - pa.scalar(0, type=divided.type), + pc.bit_wise_xor(left, right), lit(0, type=divided.type) ) result = pc.if_else( - pc.and_( - has_remainder, - has_one_negative_operand, - ), + pc.and_(has_remainder, has_one_negative_operand), # GH: 55561 ruff: ignore - pc.subtract(divided, pa.scalar(1, type=divided.type)), + pc.subtract(divided, lit(1, type=divided.type)), divided, ) else: @@ -334,8 +385,12 @@ def floordiv_compat(left: Any, right: Any) -> Any: def cast_for_truediv( - arrow_array: pa.ChunkedArray | pa.Scalar, pa_object: pa.ChunkedArray | pa.Scalar -) -> tuple[pa.ChunkedArray | pa.Scalar, pa.ChunkedArray | pa.Scalar]: + arrow_array: ArrowChunkedArray | pa.Scalar[Any], + pa_object: ArrowChunkedArray | ArrowArray | pa.Scalar[Any], +) -> tuple[ + ArrowChunkedArray | pa.Scalar[Any], + ArrowChunkedArray | ArrowArray | pa.Scalar[Any], +]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 # Ensure int / int -> float mirroring Python/Numpy behavior @@ -350,7 +405,9 @@ def cast_for_truediv( return arrow_array, pa_object -def broadcast_series(series: Sequence[ArrowSeries]) -> list[Any]: +def broadcast_series( + series: Sequence[ArrowSeries], +) -> Sequence[ArrowChunkedArray]: lengths = [len(s) for s in series] max_length = max(lengths) fast_path = all(_len == 
max_length for _len in lengths) @@ -366,7 +423,11 @@ def broadcast_series(series: Sequence[ArrowSeries]) -> list[Any]: value = s_native[0] if s._backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() - reshaped.append(pa.array([value] * max_length, type=s_native.type)) + arr = cast( + "ArrowChunkedArray", + pa.array([value] * max_length, type=s_native.type), + ) + reshaped.append(arr) else: reshaped.append(s_native) @@ -433,12 +494,23 @@ def convert_str_slice_to_int_slice( TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S")) -def parse_datetime_format(arr: pa.StringArray) -> str: - """Try to infer datetime format from StringArray.""" - matches = pa.concat_arrays( # converts from ChunkedArray to StructArray - pc.extract_regex(pc.drop_null(arr).slice(0, 10), pattern=FULL_RE).chunks +def _extract_regex_concat_arrays( + strings: ArrowChunkedArray, + /, + pattern: str, + *, + options: Any = None, + memory_pool: Any = None, +) -> pa.StructArray: + r = pa.concat_arrays( + extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks ) + return cast("pa.StructArray", r) + +def parse_datetime_format(arr: ArrowChunkedArray) -> str: + """Try to infer datetime format from StringArray.""" + matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE) if not pc.all(matches.is_valid()).as_py(): msg = ( "Unable to infer datetime format, provided format is not supported. " @@ -446,9 +518,7 @@ def parse_datetime_format(arr: pa.StringArray) -> str: ) raise NotImplementedError(msg) - dates = matches.field("date") separators = matches.field("sep") - times = matches.field("time") tz = matches.field("tz") # separators and time zones must be unique @@ -460,8 +530,8 @@ def parse_datetime_format(arr: pa.StringArray) -> str: msg = "Found multiple timezone values while inferring datetime format." 
raise ValueError(msg) - date_value = _parse_date_format(dates) - time_value = _parse_time_format(times) + date_value = _parse_date_format(cast("StringArray", matches.field("date"))) + time_value = _parse_time_format(cast("StringArray", matches.field("time"))) sep_value = separators[0].as_py() tz_value = "%z" if tz[0].as_py() else "" @@ -469,7 +539,7 @@ def parse_datetime_format(arr: pa.StringArray) -> str: return f"{date_value}{sep_value}{time_value}{tz_value}" -def _parse_date_format(arr: pa.Array) -> str: +def _parse_date_format(arr: StringArray) -> str: for date_rgx, date_fmt in DATE_FORMATS: matches = pc.extract_regex(arr, pattern=date_rgx) if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py(): @@ -489,7 +559,7 @@ def _parse_date_format(arr: pa.Array) -> str: raise ValueError(msg) -def _parse_time_format(arr: pa.Array) -> str: +def _parse_time_format(arr: StringArray) -> str: for time_rgx, time_fmt in TIME_FORMATS: matches = pc.extract_regex(arr, pattern=time_rgx) if pc.all(matches.is_valid()).as_py(): @@ -510,8 +580,6 @@ def pad_series( Returns: A tuple containing the padded ArrowSeries and the offset value. 
""" - # ignore-banned-import - if center: offset_left = window_size // 2 offset_right = offset_left - ( diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py index 43682695d9..29885b0635 100644 --- a/narwhals/_interchange/dataframe.py +++ b/narwhals/_interchange/dataframe.py @@ -125,7 +125,9 @@ def to_pandas(self: Self) -> pd.DataFrame: raise NotImplementedError(msg) def to_arrow(self: Self) -> pa.Table: - from pyarrow.interchange import from_dataframe # ignore-banned-import() + from pyarrow.interchange.from_dataframe import ( # ignore-banned-import() + from_dataframe, + ) return from_dataframe(self._interchange_frame) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 2b047b3cfc..cbbc1c3883 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -35,9 +35,9 @@ import pandas as pd import polars as pl - import pyarrow as pa from typing_extensions import Self + from narwhals._arrow.typing import ArrowArray from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals.dtypes import DType from narwhals.typing import _1DArray @@ -865,7 +865,7 @@ def clip( self._native_series.clip(lower_bound, upper_bound, **kwargs) ) - def to_arrow(self: Self) -> pa.Array: + def to_arrow(self: Self) -> ArrowArray: if self._implementation is Implementation.CUDF: return self._native_series.to_arrow() diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index c4d025e50b..4a4121e8d3 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -22,6 +22,7 @@ from typing_extensions import TypeGuard from typing_extensions import TypeIs + from narwhals._arrow.typing import ArrowChunkedArray from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series @@ -210,7 +211,9 @@ def is_polars_series(ser: Any) -> TypeIs[pl.Series]: return (pl := get_polars()) is not None and isinstance(ser, pl.Series) -def 
is_pyarrow_chunked_array(ser: Any) -> TypeIs[pa.ChunkedArray]: +def is_pyarrow_chunked_array( + ser: Any | ArrowChunkedArray, +) -> TypeIs[ArrowChunkedArray]: """Check whether `ser` is a PyArrow ChunkedArray without importing PyArrow.""" return (pa := get_pyarrow()) is not None and isinstance(ser, pa.ChunkedArray) diff --git a/narwhals/dtypes.py b/narwhals/dtypes.py index 46e433bc06..2f408b7a35 100644 --- a/narwhals/dtypes.py +++ b/narwhals/dtypes.py @@ -504,7 +504,7 @@ def __init__( if isinstance(time_zone, timezone): time_zone = str(time_zone) - self.time_unit = time_unit + self.time_unit: TimeUnit = time_unit self.time_zone = time_zone def __eq__(self: Self, other: object) -> bool: @@ -734,6 +734,8 @@ class List(NestedType): List(String) """ + inner: DType | type[DType] + def __init__(self: Self, inner: DType | type[DType]) -> None: self.inner = inner @@ -798,7 +800,7 @@ def __init__( self.size = shape self.shape = (shape, *inner_shape) - elif isinstance(shape, tuple) and isinstance(shape[0], int): + elif isinstance(shape, tuple) and len(shape) != 0 and isinstance(shape[0], int): if len(shape) > 1: inner = Array(inner, shape[1:]) diff --git a/narwhals/series.py b/narwhals/series.py index e174b492ca..409f518c21 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -28,9 +28,9 @@ import pandas as pd import polars as pl - import pyarrow as pa from typing_extensions import Self + from narwhals._arrow.typing import ArrowArray from narwhals.dataframe import DataFrame from narwhals.dtypes import DType from narwhals.typing import _1DArray @@ -186,7 +186,9 @@ def __arrow_c_stream__(self: Self, requested_schema: object | None = None) -> ob if parse_version(pa) < (16, 0): # pragma: no cover msg = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(native_series)}" raise ModuleNotFoundError(msg) - ca = pa.chunked_array([self.to_arrow()]) # type: ignore[call-overload, unused-ignore] + from narwhals._arrow.utils import chunked_array + 
+ ca = chunked_array(self.to_arrow()) return ca.__arrow_c_stream__(requested_schema=requested_schema) def to_native(self: Self) -> IntoSeriesT: @@ -2047,7 +2049,7 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: self._compliant_series.gather_every(n=n, offset=offset) ) - def to_arrow(self: Self) -> pa.Array: + def to_arrow(self: Self) -> ArrowArray: r"""Convert to arrow. Returns: diff --git a/narwhals/translate.py b/narwhals/translate.py index eff7f31731..c3e6ac8b8d 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -44,6 +44,7 @@ import polars as pl import pyarrow as pa + from narwhals._arrow.typing import ArrowChunkedArray from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame from narwhals.series import Series @@ -796,7 +797,7 @@ def get_native_namespace( | pl.LazyFrame | pl.Series | pa.Table - | pa.ChunkedArray, + | ArrowChunkedArray, ) -> Any: """Get native namespace from object. diff --git a/narwhals/utils.py b/narwhals/utils.py index 96f912cd5d..031c466f8a 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -983,7 +983,9 @@ def is_ordered_categorical(series: Series[Any]) -> bool: if is_cudf_series(native_series): # pragma: no cover return native_series.cat.ordered # type: ignore[no-any-return] if is_pyarrow_chunked_array(native_series): - return native_series.type.ordered # type: ignore[no-any-return] + from narwhals._arrow.utils import is_dictionary + + return is_dictionary(native_series.type) and native_series.type.ordered # If it doesn't match any of the above, let's just play it safe and return False. 
return False # pragma: no cover diff --git a/pyproject.toml b/pyproject.toml index 17034d1a7b..f0a8de68f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,8 +60,7 @@ core = [ "pandas", "polars", "pyarrow", - #TODO: reintroduce when fixing #1961 - # "pyarrow-stubs", + "pyarrow-stubs", ] extra = [ # heavier dependencies we don't necessarily need in every testing job "scikit-learn", @@ -201,6 +200,7 @@ omit = [ 'narwhals/typing.py', 'narwhals/stable/v1/typing.py', 'narwhals/this.py', + 'narwhals/_arrow/typing.py', # we can't run this in every environment that we measure coverage on due to upper-bound constraits 'narwhals/_ibis/*', # the latest pyspark (3.5) doesn't officially support Python 3.12 and 3.13 @@ -236,7 +236,6 @@ module = [ "modin.*", "numpy.*", "pandas.*", - "pyarrow.*", "pyspark.*", "sklearn.*", "sqlframe.*", diff --git a/tests/expr_and_series/is_nan_test.py b/tests/expr_and_series/is_nan_test.py index b319835000..d4065c2210 100644 --- a/tests/expr_and_series/is_nan_test.py +++ b/tests/expr_and_series/is_nan_test.py @@ -104,9 +104,11 @@ def test_nan_non_float(constructor: Constructor, request: pytest.FixtureRequest) data = {"a": ["x", "y"]} df = nw.from_native(constructor(data)) - exc = InvalidOperationError - if "pyarrow_table" in str(constructor): - exc = ArrowNotImplementedError + exc = ( + ArrowNotImplementedError + if "pyarrow_table" in str(constructor) + else InvalidOperationError + ) with pytest.raises(exc): df.select(nw.col("a").is_nan()).lazy().collect() @@ -120,9 +122,11 @@ def test_nan_non_float_series(constructor_eager: ConstructorEager) -> None: data = {"a": ["x", "y"]} df = nw.from_native(constructor_eager(data), eager_only=True) - exc = InvalidOperationError - if "pyarrow_table" in str(constructor_eager): - exc = ArrowNotImplementedError + exc = ( + ArrowNotImplementedError + if "pyarrow_table" in str(constructor_eager) + else InvalidOperationError + ) with pytest.raises(exc): df["a"].is_nan() diff --git 
a/tests/expr_and_series/rolling_var_test.py b/tests/expr_and_series/rolling_var_test.py index 86b47df330..c568e7aae2 100644 --- a/tests/expr_and_series/rolling_var_test.py +++ b/tests/expr_and_series/rolling_var_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import random +from typing import TYPE_CHECKING from typing import Any import hypothesis.strategies as st @@ -16,6 +17,9 @@ from tests.utils import ConstructorEager from tests.utils import assert_equal_data +if TYPE_CHECKING: + from narwhals.typing import Frame + data = {"a": [1.0, 2.0, 1.0, 3.0, 1.0, 4.0, 1.0]} kwargs_and_expected = ( @@ -122,7 +126,7 @@ def test_rolling_var_hypothesis(center: bool, values: list[float]) -> None: # n .to_frame("a") ) - result = nw.from_native(pa.Table.from_pandas(df)).select( + result: Frame = nw.from_native(pa.Table.from_pandas(df)).select( nw.col("a").rolling_var( window_size, center=center, min_samples=min_samples, ddof=ddof ) diff --git a/tests/expr_and_series/str/to_datetime_test.py b/tests/expr_and_series/str/to_datetime_test.py index 24687af9ec..af72720e90 100644 --- a/tests/expr_and_series/str/to_datetime_test.py +++ b/tests/expr_and_series/str/to_datetime_test.py @@ -164,7 +164,7 @@ def test_pyarrow_infer_datetime_raise_invalid() -> None: NotImplementedError, match="Unable to infer datetime format, provided format is not supported.", ): - parse_datetime_format(pa.chunked_array([["2024-01-01", "abc"]])) + parse_datetime_format(pa.chunked_array([["2024-01-01", "abc"]])) # type: ignore[arg-type] @pytest.mark.parametrize( @@ -181,7 +181,7 @@ def test_pyarrow_infer_datetime_raise_not_unique( ValueError, match=f"Found multiple {duplicate} values while inferring datetime format.", ): - parse_datetime_format(pa.chunked_array([data])) + parse_datetime_format(pa.chunked_array([data])) # type: ignore[arg-type] @pytest.mark.parametrize("data", [["2024-01-01", "2024-12-01", "02-02-2024"]]) @@ -189,4 +189,4 @@ def test_pyarrow_infer_datetime_raise_inconsistent_date_fmt( 
data: list[str | None], ) -> None: with pytest.raises(ValueError, match="Unable to infer datetime format. "): - parse_datetime_format(pa.chunked_array([data])) + parse_datetime_format(pa.chunked_array([data])) # type: ignore[arg-type] diff --git a/tests/series_only/is_ordered_categorical_test.py b/tests/series_only/is_ordered_categorical_test.py index 7e7db5f238..547b951daa 100644 --- a/tests/series_only/is_ordered_categorical_test.py +++ b/tests/series_only/is_ordered_categorical_test.py @@ -1,6 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from typing import Any import pandas as pd import polars as pl @@ -11,10 +12,12 @@ from tests.utils import PANDAS_VERSION if TYPE_CHECKING: + from narwhals.typing import IntoSeries from tests.utils import ConstructorEager def test_is_ordered_categorical() -> None: + s: IntoSeries | Any s = pl.Series(["a", "b"], dtype=pl.Categorical) assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) s = pl.Series(["a", "b"], dtype=pl.Categorical(ordering="lexical")) @@ -25,9 +28,8 @@ def test_is_ordered_categorical() -> None: assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) s = pd.Series(["a", "b"], dtype=pd.CategoricalDtype(ordered=False)) assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True)) - s = pa.chunked_array( - [pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string()))] - ) + tp = pa.dictionary(pa.int32(), pa.string()) + s = pa.chunked_array([pa.array(["a", "b"], type=tp)], type=tp) assert not nw.is_ordered_categorical(nw.from_native(s, series_only=True)) @@ -51,7 +53,6 @@ def test_is_definitely_not_ordered_categorical( @pytest.mark.xfail(reason="https://github.com/apache/arrow/issues/41017") def test_is_ordered_categorical_pyarrow() -> None: - s = pa.chunked_array( - [pa.array(["a", "b"], type=pa.dictionary(pa.int32(), pa.string(), ordered=True))] - ) + tp = pa.dictionary(pa.int32(), pa.string(), ordered=True) + s = 
pa.chunked_array([pa.array(["a", "b"], type=tp)]) # type: ignore[list-item] assert nw.is_ordered_categorical(nw.from_native(s, series_only=True)) diff --git a/tests/translate/get_native_namespace_test.py b/tests/translate/get_native_namespace_test.py index 5ad3f94ebf..b269caef75 100644 --- a/tests/translate/get_native_namespace_test.py +++ b/tests/translate/get_native_namespace_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import pandas as pd import polars as pl import pyarrow as pa @@ -7,9 +9,12 @@ import narwhals.stable.v1 as nw +if TYPE_CHECKING: + from narwhals.typing import Frame + def test_native_namespace() -> None: - df = nw.from_native(pl.DataFrame({"a": [1, 2, 3]})) + df: Frame = nw.from_native(pl.DataFrame({"a": [1, 2, 3]})) assert nw.get_native_namespace(df) is pl assert nw.get_native_namespace(df.to_native()) is pl assert nw.get_native_namespace(df.lazy().to_native()) is pl @@ -26,4 +31,4 @@ def test_native_namespace() -> None: def test_get_native_namespace_invalid() -> None: with pytest.raises(TypeError, match="Could not get native namespace"): - nw.get_native_namespace(1) + nw.get_native_namespace(1) # type: ignore[arg-type]