
Commit 1f16762

TYP: Misc changes for pandas-stubs; use Protocol to avoid str in Sequence (#55263)

* TYP: misc changes for pandas-stubs test

* re-write changes from 47233 with SequenceNotStr

* pyupgrade
twoertwein authored Sep 26, 2023
1 parent 89bd569 commit 1f16762
Showing 10 changed files with 73 additions and 36 deletions.
43 changes: 38 additions & 5 deletions pandas/_typing.py
@@ -24,6 +24,7 @@
Type as type_t,
TypeVar,
Union,
overload,
)

import numpy as np
@@ -85,6 +86,8 @@
# Name "npt._ArrayLikeInt_co" is not defined [name-defined]
NumpySorter = Optional[npt._ArrayLikeInt_co] # type: ignore[name-defined]

from typing import SupportsIndex

if sys.version_info >= (3, 10):
from typing import TypeGuard # pyright: ignore[reportUnusedImport]
else:
@@ -109,18 +112,48 @@

# list-like

# Cannot use `Sequence` because a string is a sequence, and we don't want to
# accept that. Could refine if https://github.com/python/typing/issues/256 is
# resolved to differentiate between Sequence[str] and str
ListLike = Union[AnyArrayLike, list, tuple, range]
# from https://github.com/hauntsaninja/useful_types
# includes Sequence-like objects but excludes str and bytes
_T_co = TypeVar("_T_co", covariant=True)


class SequenceNotStr(Protocol[_T_co]):
    @overload
    def __getitem__(self, index: SupportsIndex, /) -> _T_co:
        ...

    @overload
    def __getitem__(self, index: slice, /) -> Sequence[_T_co]:
        ...

    def __contains__(self, value: object, /) -> bool:
        ...

    def __len__(self) -> int:
        ...

    def __iter__(self) -> Iterator[_T_co]:
        ...

    def index(self, value: Any, /, start: int = 0, stop: int = ...) -> int:
        ...

    def count(self, value: Any, /) -> int:
        ...

    def __reversed__(self) -> Iterator[_T_co]:
        ...


ListLike = Union[AnyArrayLike, SequenceNotStr, range]

# scalars

PythonScalar = Union[str, float, bool]
DatetimeLikeScalar = Union["Period", "Timestamp", "Timedelta"]
PandasScalar = Union["Period", "Timestamp", "Timedelta", "Interval"]
Scalar = Union[PythonScalar, PandasScalar, np.datetime64, np.timedelta64, date]
IntStrT = TypeVar("IntStrT", int, str)
IntStrT = TypeVar("IntStrT", bound=Union[int, str])


# timestamp and timedelta convertible types
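The reason for the new protocol: a plain str is itself a Sequence[str], so parameters annotated with Sequence[str] silently accept a single string. SequenceNotStr sidesteps that because, in typeshed, str.__contains__ only accepts str (not object), so str and bytes fail to match the protocol structurally. Below is a minimal sketch of the effect under a static type checker such as mypy or pyright; the column_names function and the calls are illustrative, not part of the commit, and the import relies on the private pandas._typing module as it exists after this change:

    from __future__ import annotations

    from pandas._typing import SequenceNotStr

    def column_names(names: SequenceNotStr[str]) -> list[str]:
        # Accepts list, tuple and similar sequence containers of strings.
        return list(names)

    column_names(["a", "b"])   # OK
    column_names(("a", "b"))   # OK
    column_names("ab")         # rejected by the type checker (still runs at
                               # runtime, but str does not satisfy the protocol)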
13 changes: 7 additions & 6 deletions pandas/core/frame.py
@@ -240,6 +240,7 @@
Renamer,
Scalar,
Self,
SequenceNotStr,
SortKind,
StorageOptions,
Suffixes,
@@ -1187,7 +1188,7 @@ def to_string(
buf: None = ...,
columns: Axes | None = ...,
col_space: int | list[int] | dict[Hashable, int] | None = ...,
header: bool | list[str] = ...,
header: bool | SequenceNotStr[str] = ...,
index: bool = ...,
na_rep: str = ...,
formatters: fmt.FormattersType | None = ...,
@@ -1212,7 +1213,7 @@ def to_string(
buf: FilePath | WriteBuffer[str],
columns: Axes | None = ...,
col_space: int | list[int] | dict[Hashable, int] | None = ...,
header: bool | list[str] = ...,
header: bool | SequenceNotStr[str] = ...,
index: bool = ...,
na_rep: str = ...,
formatters: fmt.FormattersType | None = ...,
@@ -1250,7 +1251,7 @@ def to_string(
buf: FilePath | WriteBuffer[str] | None = None,
columns: Axes | None = None,
col_space: int | list[int] | dict[Hashable, int] | None = None,
header: bool | list[str] = True,
header: bool | SequenceNotStr[str] = True,
index: bool = True,
na_rep: str = "NaN",
formatters: fmt.FormattersType | None = None,
@@ -10563,9 +10564,9 @@ def merge(
self,
right: DataFrame | Series,
how: MergeHow = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
on: IndexLabel | AnyArrayLike | None = None,
left_on: IndexLabel | AnyArrayLike | None = None,
right_on: IndexLabel | AnyArrayLike | None = None,
left_index: bool = False,
right_index: bool = False,
sort: bool = False,
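For DataFrame.to_string, the practical effect of the header change is that any non-string sequence of strings is accepted, not only list[str]. A small illustrative call (the frame is made up; the tuple argument is the point):

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})

    # Under the old annotation a tuple of column aliases needed a cast or a
    # type: ignore to satisfy the checker; with SequenceNotStr[str] it is fine.
    text = df.to_string(header=("first", "second"))
    print(text)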
7 changes: 4 additions & 3 deletions pandas/core/generic.py
@@ -72,6 +72,7 @@
Renamer,
Scalar,
Self,
SequenceNotStr,
SortKind,
StorageOptions,
Suffixes,
@@ -3273,7 +3274,7 @@ def to_latex(
self,
buf: None = ...,
columns: Sequence[Hashable] | None = ...,
header: bool_t | list[str] = ...,
header: bool_t | SequenceNotStr[str] = ...,
index: bool_t = ...,
na_rep: str = ...,
formatters: FormattersType | None = ...,
@@ -3300,7 +3301,7 @@ def to_latex(
self,
buf: FilePath | WriteBuffer[str],
columns: Sequence[Hashable] | None = ...,
header: bool_t | list[str] = ...,
header: bool_t | SequenceNotStr[str] = ...,
index: bool_t = ...,
na_rep: str = ...,
formatters: FormattersType | None = ...,
@@ -3330,7 +3331,7 @@ def to_latex(
self,
buf: FilePath | WriteBuffer[str] | None = None,
columns: Sequence[Hashable] | None = None,
header: bool_t | list[str] = True,
header: bool_t | SequenceNotStr[str] = True,
index: bool_t = True,
na_rep: str = "NaN",
formatters: FormattersType | None = None,
2 changes: 1 addition & 1 deletion pandas/core/methods/describe.py
@@ -301,7 +301,7 @@ def describe_timestamp_as_categorical_1d(
names = ["count", "unique"]
objcounts = data.value_counts()
count_unique = len(objcounts[objcounts != 0])
result = [data.count(), count_unique]
result: list[float | Timestamp] = [data.count(), count_unique]
dtype = None
if count_unique > 0:
top, freq = objcounts.index[0], objcounts.iloc[0]
2 changes: 1 addition & 1 deletion pandas/core/resample.py
@@ -1541,7 +1541,7 @@ def count(self):

return result

def quantile(self, q: float | AnyArrayLike = 0.5, **kwargs):
def quantile(self, q: float | list[float] | AnyArrayLike = 0.5, **kwargs):
"""
Return value at the given quantile.
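Resampler.quantile already accepted a list of quantiles at runtime; the annotation now says so explicitly. Illustrative usage with a made-up series:

    import pandas as pd

    s = pd.Series(range(10), index=pd.date_range("2023-01-01", periods=10, freq="D"))

    # A single float gives one value per resample bin; a list of floats gives
    # a result with an extra index level for the quantiles.
    single = s.resample("5D").quantile(0.5)
    several = s.resample("5D").quantile([0.25, 0.5, 0.75])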
28 changes: 15 additions & 13 deletions pandas/core/reshape/merge.py
@@ -138,9 +138,9 @@ def merge(
left: DataFrame | Series,
right: DataFrame | Series,
how: MergeHow = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
on: IndexLabel | AnyArrayLike | None = None,
left_on: IndexLabel | AnyArrayLike | None = None,
right_on: IndexLabel | AnyArrayLike | None = None,
left_index: bool = False,
right_index: bool = False,
sort: bool = False,
@@ -187,9 +187,9 @@ def merge(
def _cross_merge(
left: DataFrame,
right: DataFrame,
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
on: IndexLabel | AnyArrayLike | None = None,
left_on: IndexLabel | AnyArrayLike | None = None,
right_on: IndexLabel | AnyArrayLike | None = None,
left_index: bool = False,
right_index: bool = False,
sort: bool = False,
@@ -239,7 +239,9 @@ def _cross_merge(
return res


def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):
def _groupby_and_merge(
by, left: DataFrame | Series, right: DataFrame | Series, merge_pieces
):
"""
groupby & merge; we are always performing a left-by type operation
@@ -255,7 +257,7 @@ def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):
by = [by]

lby = left.groupby(by, sort=False)
rby: groupby.DataFrameGroupBy | None = None
rby: groupby.DataFrameGroupBy | groupby.SeriesGroupBy | None = None

# if we can groupby the rhs
# then we can get vastly better perf
@@ -295,8 +297,8 @@ def _groupby_and_merge(by, left: DataFrame, right: DataFrame, merge_pieces):


def merge_ordered(
left: DataFrame,
right: DataFrame,
left: DataFrame | Series,
right: DataFrame | Series,
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
@@ -737,9 +739,9 @@ def __init__(
left: DataFrame | Series,
right: DataFrame | Series,
how: MergeHow | Literal["asof"] = "inner",
on: IndexLabel | None = None,
left_on: IndexLabel | None = None,
right_on: IndexLabel | None = None,
on: IndexLabel | AnyArrayLike | None = None,
left_on: IndexLabel | AnyArrayLike | None = None,
right_on: IndexLabel | AnyArrayLike | None = None,
left_index: bool = False,
right_index: bool = False,
sort: bool = True,
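The widened merge signatures advertise what pd.merge has long supported at runtime: join keys may be array-likes of the frame's length (ndarray, Series, Index), not only column labels. A hedged sketch with made-up frames and keys:

    import numpy as np
    import pandas as pd

    left = pd.DataFrame({"a": [1, 2, 3]})
    right = pd.DataFrame({"b": [10, 20, 30]})

    # Key arrays are aligned positionally with each frame and matched by value,
    # so this joins left row "x" with right row "x", and so on.
    left_keys = np.array(["x", "y", "z"])
    right_keys = pd.Index(["y", "z", "x"])
    result = pd.merge(left, right, left_on=left_keys, right_on=right_keys)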
2 changes: 1 addition & 1 deletion pandas/core/series.py
@@ -2141,7 +2141,7 @@ def groupby(
# Statistics, overridden ndarray methods

# TODO: integrate bottleneck
def count(self):
def count(self) -> int:
"""
Return number of non-NA/null observations in the Series.
7 changes: 4 additions & 3 deletions pandas/io/formats/csvs.py
@@ -21,6 +21,7 @@
import numpy as np

from pandas._libs import writers as libwriters
from pandas._typing import SequenceNotStr
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.generic import (
@@ -109,7 +110,7 @@ def decimal(self) -> str:
return self.fmt.decimal

@property
def header(self) -> bool | list[str]:
def header(self) -> bool | SequenceNotStr[str]:
return self.fmt.header

@property
@@ -213,7 +214,7 @@ def _need_to_save_header(self) -> bool:
return bool(self._has_aliases or self.header)

@property
def write_cols(self) -> Sequence[Hashable]:
def write_cols(self) -> SequenceNotStr[Hashable]:
if self._has_aliases:
assert not isinstance(self.header, bool)
if len(self.header) != len(self.cols):
@@ -224,7 +225,7 @@ def write_cols(self) -> Sequence[Hashable]:
else:
# self.cols is an ndarray derived from Index._format_native_types,
# so its entries are strings, i.e. hashable
return cast(Sequence[Hashable], self.cols)
return cast(SequenceNotStr[Hashable], self.cols)

@property
def encoded_labels(self) -> list[Hashable]:
3 changes: 2 additions & 1 deletion pandas/io/formats/format.py
@@ -105,6 +105,7 @@
FloatFormatType,
FormattersType,
IndexLabel,
SequenceNotStr,
StorageOptions,
WriteBuffer,
)
@@ -566,7 +567,7 @@ def __init__(
frame: DataFrame,
columns: Axes | None = None,
col_space: ColspaceArgType | None = None,
header: bool | list[str] = True,
header: bool | SequenceNotStr[str] = True,
index: bool = True,
na_rep: str = "NaN",
formatters: FormattersType | None = None,
2 changes: 0 additions & 2 deletions pandas/tests/io/test_sql.py
@@ -3161,8 +3161,6 @@ def dtype_backend_data() -> DataFrame:
@pytest.fixture
def dtype_backend_expected():
def func(storage, dtype_backend, conn_name):
string_array: StringArray | ArrowStringArray
string_array_na: StringArray | ArrowStringArray
if storage == "python":
string_array = StringArray(np.array(["a", "b", "c"], dtype=np.object_))
string_array_na = StringArray(np.array(["a", "b", pd.NA], dtype=np.object_))
