diff --git a/pyarrow-stubs/_stubs_typing.pyi b/pyarrow-stubs/_stubs_typing.pyi index dbcb2c5..6c25aa2 100644 --- a/pyarrow-stubs/_stubs_typing.pyi +++ b/pyarrow-stubs/_stubs_typing.pyi @@ -1,5 +1,6 @@ import datetime as dt +from collections.abc import Sequence from decimal import Decimal from typing import Any, Collection, Literal, Protocol, TypeAlias, TypeVar @@ -7,7 +8,7 @@ import numpy as np from numpy.typing import NDArray -from .__lib_pxi.array import BooleanArray, IntegerArray +from .compute import BooleanArray, IntegerArray ArrayLike: TypeAlias = Any ScalarLike: TypeAlias = Any @@ -27,8 +28,8 @@ Compression: TypeAlias = Literal[ ] NullEncoding: TypeAlias = Literal["mask", "encode"] NullSelectionBehavior: TypeAlias = Literal["drop", "emit_null"] -Mask: TypeAlias = list[bool | None] | NDArray[np.bool_] | BooleanArray -Indices: TypeAlias = list[int] | NDArray[np.integer] | IntegerArray +Mask: TypeAlias = Sequence[bool | None] | NDArray[np.bool_] | BooleanArray +Indices: TypeAlias = Sequence[int] | NDArray[np.integer] | IntegerArray PyScalar: TypeAlias = ( bool | int | float | Decimal | str | bytes | dt.date | dt.datetime | dt.time | dt.timedelta ) diff --git a/pyarrow-stubs/compute.pyi b/pyarrow-stubs/compute.pyi index d764795..db288d4 100644 --- a/pyarrow-stubs/compute.pyi +++ b/pyarrow-stubs/compute.pyi @@ -94,27 +94,38 @@ def _clone_signature(f: Callable[_P, _R]) -> Callable[_P, _R]: ... # ============= compute functions ============= _DataTypeT = TypeVar("_DataTypeT", bound=lib.DataType) -NumericScalar: TypeAlias = ( +_Scalar_CoT = TypeVar("_Scalar_CoT", bound=lib.Scalar, covariant=True) +_ScalarT = TypeVar("_ScalarT", bound=lib.Scalar) +_ArrayT = TypeVar("_ArrayT", bound=lib.Array | lib.ChunkedArray) +_ScalarOrArrayT = TypeVar("_ScalarOrArrayT", bound=lib.Array | lib.Scalar | lib.ChunkedArray) +ArrayOrChunkedArray: TypeAlias = lib.Array[_Scalar_CoT] | lib.ChunkedArray[_Scalar_CoT] +ScalarOrArray: TypeAlias = ArrayOrChunkedArray[_Scalar_CoT] | _Scalar_CoT + +SignedIntegerScalar: TypeAlias = ( lib.Scalar[lib.Int8Type] | lib.Scalar[lib.Int16Type] | lib.Scalar[lib.Int32Type] | lib.Scalar[lib.Int64Type] - | lib.Scalar[lib.Uint8Type] +) +UnsignedIntegerScalar: TypeAlias = ( + lib.Scalar[lib.Uint8Type] | lib.Scalar[lib.Uint16Type] | lib.Scalar[lib.Uint32Type] | lib.Scalar[lib.Uint64Type] - | lib.Scalar[lib.Float16Type] - | lib.Scalar[lib.Float32Type] - | lib.Scalar[lib.Float64Type] - | lib.Scalar[lib.Decimal128Type] - | lib.Scalar[lib.Decimal256Type] ) +IntegerScalar: TypeAlias = SignedIntegerScalar | UnsignedIntegerScalar +FloatScalar: TypeAlias = ( + lib.Scalar[lib.Float16Type] | lib.Scalar[lib.Float32Type] | lib.Scalar[lib.Float64Type] +) +DecimalScalar: TypeAlias = lib.Scalar[lib.Decimal128Type] | lib.Scalar[lib.Decimal256Type] +NumericScalar: TypeAlias = IntegerScalar | FloatScalar | DecimalScalar BinaryScalar: TypeAlias = ( lib.Scalar[lib.BinaryType] | lib.Scalar[lib.LargeBinaryType] | lib.Scalar[lib.FixedSizeBinaryType] ) StringScalar: TypeAlias = lib.Scalar[lib.StringType] | lib.Scalar[lib.LargeStringType] +StringOrBinaryScalar: TypeAlias = StringScalar | BinaryScalar ListScalar: TypeAlias = ( lib.ListScalar[_DataTypeT] | lib.LargeListScalar[_DataTypeT] @@ -131,73 +142,35 @@ TemporalScalar: TypeAlias = ( | lib.DurationScalar | lib.MonthDayNanoIntervalScalar ) -_NumericScalarT = TypeVar("_NumericScalarT", bound=NumericScalar) NumericOrDurationScalar: TypeAlias = NumericScalar | lib.DurationScalar -_NumericOrDurationT = TypeVar("_NumericOrDurationT", bound=NumericOrDurationScalar) NumericOrTemporalScalar: TypeAlias = NumericScalar | TemporalScalar + _NumericOrTemporalT = TypeVar("_NumericOrTemporalT", bound=NumericOrTemporalScalar) -NumericArray: TypeAlias = lib.NumericArray[_ScalarT] | lib.ChunkedArray[_ScalarT] -_NumericArrayT = TypeVar("_NumericArrayT", bound=lib.NumericArray) -NumericOrDurationArray: TypeAlias = ( - lib.NumericArray | lib.Array[lib.DurationScalar] | lib.ChunkedArray -) +NumericArray: TypeAlias = ArrayOrChunkedArray[_NumericScalarT] +_NumericArrayT = TypeVar("_NumericArrayT", bound=NumericArray) +_NumericScalarT = TypeVar("_NumericScalarT", bound=NumericScalar) +_NumericOrDurationT = TypeVar("_NumericOrDurationT", bound=NumericOrDurationScalar) +NumericOrDurationArray: TypeAlias = ArrayOrChunkedArray[NumericOrDurationScalar] _NumericOrDurationArrayT = TypeVar("_NumericOrDurationArrayT", bound=NumericOrDurationArray) -NumericOrTemporalArray: TypeAlias = ( - lib.NumericArray | lib.Array[TemporalScalar] | lib.ChunkedArray[TemporalScalar] -) +NumericOrTemporalArray: TypeAlias = ArrayOrChunkedArray[_NumericOrTemporalT] _NumericOrTemporalArrayT = TypeVar("_NumericOrTemporalArrayT", bound=NumericOrTemporalArray) -BooleanArray: TypeAlias = lib.BooleanArray | lib.ChunkedArray[lib.BooleanScalar] -FloatScalar: TypeAlias = lib.Scalar[lib.Float32Type] | lib.Scalar[lib.Float64Type] -DecimalScalar: TypeAlias = lib.Scalar[lib.Decimal128Type] | lib.Scalar[lib.Decimal256Type] +BooleanArray: TypeAlias = ArrayOrChunkedArray[lib.BooleanScalar] +IntegerArray: TypeAlias = ArrayOrChunkedArray[IntegerScalar] _FloatScalarT = TypeVar("_FloatScalarT", bound=FloatScalar) -FloatArray: TypeAlias = ( - lib.NumericArray[lib.FloatScalar] - | lib.NumericArray[lib.DoubleScalar] - | lib.ChunkedArray[lib.FloatScalar] - | lib.ChunkedArray[lib.DoubleScalar] -) - +FloatArray: TypeAlias = ArrayOrChunkedArray[FloatScalar] _FloatArrayT = TypeVar("_FloatArrayT", bound=FloatArray) _StringScalarT = TypeVar("_StringScalarT", bound=StringScalar) -StringArray: TypeAlias = ( - lib.StringArray - | lib.LargeStringArray - | lib.ChunkedArray[lib.StringScalar] - | lib.ChunkedArray[lib.LargeStringScalar] -) +StringArray: TypeAlias = ArrayOrChunkedArray[StringScalar] _StringArrayT = TypeVar("_StringArrayT", bound=StringArray) _BinaryScalarT = TypeVar("_BinaryScalarT", bound=BinaryScalar) -BinaryArray: TypeAlias = ( - lib.BinaryArray - | lib.LargeBinaryArray - | lib.ChunkedArray[lib.BinaryScalar] - | lib.ChunkedArray[lib.LargeBinaryScalar] -) +BinaryArray: TypeAlias = ArrayOrChunkedArray[BinaryScalar] _BinaryArrayT = TypeVar("_BinaryArrayT", bound=BinaryArray) -StringOrBinaryScalar: TypeAlias = StringScalar | BinaryScalar _StringOrBinaryScalarT = TypeVar("_StringOrBinaryScalarT", bound=StringOrBinaryScalar) StringOrBinaryArray: TypeAlias = StringArray | BinaryArray _StringOrBinaryArrayT = TypeVar("_StringOrBinaryArrayT", bound=StringOrBinaryArray) _TemporalScalarT = TypeVar("_TemporalScalarT", bound=TemporalScalar) -TemporalArray: TypeAlias = ( - lib.Date32Array - | lib.Date64Array - | lib.Time32Array - | lib.Time64Array - | lib.TimestampArray - | lib.DurationArray - | lib.MonthDayNanoIntervalArray - | lib.ChunkedArray[lib.Date32Scalar] - | lib.ChunkedArray[lib.Date64Scalar] - | lib.ChunkedArray[lib.Time32Scalar] - | lib.ChunkedArray[lib.Time64Scalar] - | lib.ChunkedArray[lib.DurationScalar] - | lib.ChunkedArray[lib.MonthDayNanoIntervalScalar] -) +TemporalArray: TypeAlias = ArrayOrChunkedArray[TemporalScalar] _TemporalArrayT = TypeVar("_TemporalArrayT", bound=TemporalArray) -_ScalarT = TypeVar("_ScalarT", bound=lib.Scalar) -_ArrayT = TypeVar("_ArrayT", bound=lib.Array) -_ScalarOrArrayT = TypeVar("_ScalarOrArrayT", bound=lib.Array | lib.Scalar) # =============================== 1. Aggregation =============================== # ========================= 1.1 functions ========================= @@ -940,18 +913,11 @@ not_equal = _clone_signature(equal) @overload def max_element_wise( - *args: _ScalarT, - skip_nulls: bool = True, - options: ElementWiseAggregateOptions | None = None, - memory_pool: lib.MemoryPool | None = None, -) -> _ScalarT: ... -@overload -def max_element_wise( - *args: _ArrayT, + *args: ScalarOrArray[_Scalar_CoT], skip_nulls: bool = True, options: ElementWiseAggregateOptions | None = None, memory_pool: lib.MemoryPool | None = None, -) -> _ArrayT: ... +) -> _Scalar_CoT: ... @overload def max_element_wise( *args: Expression, @@ -960,7 +926,7 @@ def max_element_wise( memory_pool: lib.MemoryPool | None = None, ) -> Expression: ... -min_element_wise = _clone_signature(equal) +min_element_wise = _clone_signature(max_element_wise) # ========================= 2.6 Logical functions ========================= @overload