diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index a7ffd2f027..5230cdb410 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -271,14 +271,14 @@ def __truediv__(self: Self, other: Any) -> Self: if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar other = lit(other) - return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[arg-type, type-var] + return self._with_native(pc.divide(*cast_for_truediv(ser, other))) # type: ignore[type-var] def __rtruediv__(self: Self, other: Any) -> Self: ser, other = extract_native(self, other) if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar other = lit(other) if not isinstance(other, pa.Scalar) else other - return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[arg-type, type-var] + return self._with_native(pc.divide(*cast_for_truediv(other, ser))) # type: ignore[type-var] def __mod__(self: Self, other: Any) -> Self: floor_div = (self // other).native @@ -596,11 +596,11 @@ def value_counts( counts = cast("ArrowChunkedArray", val_counts.field("counts")) if normalize: - arrays = [values, pc.divide(*cast_for_truediv(counts, pc.sum(counts)))] # type: ignore[type-var] + arrays = [values, pc.divide(*cast_for_truediv(counts, pc.sum(counts)))] else: arrays = [values, counts] - val_count = pa.Table.from_arrays(arrays, names=[index_name_, value_name_]) # type: ignore[arg-type] + val_count = pa.Table.from_arrays(arrays, names=[index_name_, value_name_]) if sort: val_count = val_count.sort_by([(value_name_, "descending")]) diff --git a/narwhals/_arrow/typing.py b/narwhals/_arrow/typing.py index 2382be617d..1fcf4939b8 100644 --- a/narwhals/_arrow/typing.py +++ b/narwhals/_arrow/typing.py @@ -34,6 +34,11 @@ ArrowChunkedArray: TypeAlias = pa.ChunkedArray[Any] ArrowArray: TypeAlias = pa.Array[Any] + ArrayAny: TypeAlias = "ArrowArray | ArrowChunkedArray" + ScalarAny: TypeAlias = pa.Scalar[Any] + ArrayOrScalarAny: TypeAlias = "ArrayAny | ScalarAny" + ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrowArray, ArrowChunkedArray, ScalarAny) + ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrowArray, ArrowChunkedArray, ScalarAny) _AsPyType = TypeVar("_AsPyType") class _BasicDataType(pa.DataType, Generic[_AsPyType]): ... diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 87a739455b..380388d51d 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -26,6 +26,9 @@ from typing_extensions import TypeIs from narwhals._arrow.series import ArrowSeries + from narwhals._arrow.typing import ArrayAny + from narwhals._arrow.typing import ArrayOrScalarT1 + from narwhals._arrow.typing import ArrayOrScalarT2 from narwhals._arrow.typing import ArrowArray from narwhals._arrow.typing import ArrowChunkedArray from narwhals.dtypes import DType @@ -35,7 +38,6 @@ # NOTE: stubs don't allow for `ChunkedArray[StructArray]` # Intended to represent the `.chunks` property storing `list[pa.StructArray]` ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray - ArrayAny: TypeAlias = "ArrowArray | ArrowChunkedArray" _T = TypeVar("_T") @@ -356,12 +358,8 @@ def floordiv_compat(left: Any, right: Any) -> Any: def cast_for_truediv( - arrow_array: ArrowChunkedArray | pa.Scalar[Any], - pa_object: ArrowChunkedArray | ArrowArray | pa.Scalar[Any], -) -> tuple[ - ArrowChunkedArray | pa.Scalar[Any], - ArrowChunkedArray | ArrowArray | pa.Scalar[Any], -]: + arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2 +) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 # Ensure int / int -> float mirroring Python/Numpy behavior @@ -369,6 +367,7 @@ def cast_for_truediv( if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type): # GH: 56645. # noqa: ERA001 # https://github.com/apache/arrow/issues/35563 + # NOTE: `pyarrow==11.*` doesn't allow keywords in `Array.cast` return pc.cast(arrow_array, pa.float64(), safe=False), pc.cast( pa_object, pa.float64(), safe=False )