11# being a bit too dynamic
22from math import ceil
3- from typing import TYPE_CHECKING , Tuple
3+ from typing import TYPE_CHECKING , Iterable , List , Sequence , Tuple , Union
44import warnings
55
66import matplotlib .table
1515from pandas .plotting ._matplotlib import compat
1616
1717if TYPE_CHECKING :
18+ from matplotlib .axes import Axes
19+ from matplotlib .axis import Axis
20+ from matplotlib .lines import Line2D # noqa:F401
1821 from matplotlib .table import Table
1922
2023
21- def format_date_labels (ax , rot ):
24+ def format_date_labels (ax : "Axes" , rot ):
2225 # mini version of autofmt_xdate
2326 for label in ax .get_xticklabels ():
2427 label .set_ha ("right" )
@@ -278,7 +281,7 @@ def _subplots(
278281 return fig , axes
279282
280283
281- def _remove_labels_from_axis (axis ):
284+ def _remove_labels_from_axis (axis : "Axis" ):
282285 for t in axis .get_majorticklabels ():
283286 t .set_visible (False )
284287
@@ -294,7 +297,15 @@ def _remove_labels_from_axis(axis):
294297 axis .get_label ().set_visible (False )
295298
296299
297- def _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey ):
300+ def _handle_shared_axes (
301+ axarr : Iterable ["Axes" ],
302+ nplots : int ,
303+ naxes : int ,
304+ nrows : int ,
305+ ncols : int ,
306+ sharex : bool ,
307+ sharey : bool ,
308+ ):
298309 if nplots > 1 :
299310 if compat ._mpl_ge_3_2_0 ():
300311 row_num = lambda x : x .get_subplotspec ().rowspan .start
@@ -340,15 +351,21 @@ def _handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey):
340351 _remove_labels_from_axis (ax .yaxis )
341352
342353
343- def _flatten (axes ) :
354+ def _flatten (axes : Union [ "Axes" , Sequence [ "Axes" ]]) -> Sequence [ "Axes" ] :
344355 if not is_list_like (axes ):
345356 return np .array ([axes ])
346357 elif isinstance (axes , (np .ndarray , ABCIndexClass )):
347358 return axes .ravel ()
348359 return np .array (axes )
349360
350361
351- def _set_ticks_props (axes , xlabelsize = None , xrot = None , ylabelsize = None , yrot = None ):
362+ def _set_ticks_props (
363+ axes : Union ["Axes" , Sequence ["Axes" ]],
364+ xlabelsize = None ,
365+ xrot = None ,
366+ ylabelsize = None ,
367+ yrot = None ,
368+ ):
352369 import matplotlib .pyplot as plt
353370
354371 for ax in _flatten (axes ):
@@ -363,7 +380,7 @@ def _set_ticks_props(axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=Non
363380 return axes
364381
365382
366- def _get_all_lines (ax ) :
383+ def _get_all_lines (ax : "Axes" ) -> List [ "Line2D" ] :
367384 lines = ax .get_lines ()
368385
369386 if hasattr (ax , "right_ax" ):
@@ -375,7 +392,7 @@ def _get_all_lines(ax):
375392 return lines
376393
377394
378- def _get_xlim (lines ) -> Tuple [float , float ]:
395+ def _get_xlim (lines : Iterable [ "Line2D" ] ) -> Tuple [float , float ]:
379396 left , right = np .inf , - np .inf
380397 for l in lines :
381398 x = l .get_xdata (orig = False )
0 commit comments