Skip to content

Commit

Permalink
TYP: use overload to refine return type of set_axis (#40197)
Browse files Browse the repository at this point in the history
* try typing set_axis

* fixup overloads

* remove return type from base definition

* searching in the sun for another overload

* bool

* allow defaults in bool case

* type non-overloaded arguments

Co-authored-by: Simon Hawkins <[email protected]>
  • Loading branch information
MarcoGorelli and simonjayhawkins authored Mar 14, 2021
1 parent 1cc40fa commit 015c0c0
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
20 changes: 20 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4541,6 +4541,26 @@ def align(
broadcast_axis=broadcast_axis,
)

@overload
def set_axis(
self, labels, axis: Axis = ..., inplace: Literal[False] = ...
) -> DataFrame:
...

@overload
def set_axis(self, labels, axis: Axis, inplace: Literal[True]) -> None:
...

@overload
def set_axis(self, labels, *, inplace: Literal[True]) -> None:
...

@overload
def set_axis(
self, labels, axis: Axis = ..., inplace: bool = ...
) -> Optional[DataFrame]:
...

@Appender(
"""
Examples
Expand Down
25 changes: 25 additions & 0 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Type,
Union,
cast,
overload,
)
import warnings
import weakref
Expand Down Expand Up @@ -162,6 +163,8 @@
from pandas.io.formats.printing import pprint_thing

if TYPE_CHECKING:
from typing import Literal

from pandas._libs.tslibs import BaseOffset

from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -682,6 +685,28 @@ def _obj_with_exclusions(self: FrameOrSeries) -> FrameOrSeries:
""" internal compat with SelectionMixin """
return self

@overload
def set_axis(
self: FrameOrSeries, labels, axis: Axis = ..., inplace: Literal[False] = ...
) -> FrameOrSeries:
...

@overload
def set_axis(
self: FrameOrSeries, labels, axis: Axis, inplace: Literal[True]
) -> None:
...

@overload
def set_axis(self: FrameOrSeries, labels, *, inplace: Literal[True]) -> None:
...

@overload
def set_axis(
self: FrameOrSeries, labels, axis: Axis = ..., inplace: bool = ...
) -> Optional[FrameOrSeries]:
...

def set_axis(self, labels, axis: Axis = 0, inplace: bool = False):
"""
Assign desired index to given axis.
Expand Down
23 changes: 23 additions & 0 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Type,
Union,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -142,6 +143,8 @@
import pandas.plotting

if TYPE_CHECKING:
from typing import Literal

from pandas._typing import (
TimedeltaConvertibleTypes,
TimestampConvertibleTypes,
Expand Down Expand Up @@ -4342,6 +4345,26 @@ def rename(
else:
return self._set_name(index, inplace=inplace)

@overload
def set_axis(
self, labels, axis: Axis = ..., inplace: Literal[False] = ...
) -> Series:
...

@overload
def set_axis(self, labels, axis: Axis, inplace: Literal[True]) -> None:
...

@overload
def set_axis(self, labels, *, inplace: Literal[True]) -> None:
...

@overload
def set_axis(
self, labels, axis: Axis = ..., inplace: bool = ...
) -> Optional[Series]:
...

@Appender(
"""
Examples
Expand Down

0 comments on commit 015c0c0

Please sign in to comment.