11from __future__ import annotations
22
33from typing import (
4+ TYPE_CHECKING ,
45 Any ,
6+ Callable ,
7+ Hashable ,
8+ Iterable ,
9+ Literal ,
510 MutableMapping ,
11+ Sequence ,
12+ TypeVar ,
13+ overload ,
614)
715
816from pandas .compat ._optional import import_optional_dependency
1220 is_list_like ,
1321)
1422
15- _writers : MutableMapping [str , str ] = {}
23+ if TYPE_CHECKING :
24+ from pandas .io .excel ._base import ExcelWriter
1625
26+ ExcelWriter_t = type [ExcelWriter ]
27+ usecols_func = TypeVar ("usecols_func" , bound = Callable [[Hashable ], object ])
1728
18- def register_writer (klass ):
29+ _writers : MutableMapping [str , ExcelWriter_t ] = {}
30+
31+
32+ def register_writer (klass : ExcelWriter_t ) -> None :
1933 """
2034 Add engine to the excel writer registry.io.excel.
2135
@@ -28,10 +42,12 @@ def register_writer(klass):
2842 if not callable (klass ):
2943 raise ValueError ("Can only register callables as engines" )
3044 engine_name = klass .engine
45+ # for mypy
46+ assert isinstance (engine_name , str )
3147 _writers [engine_name ] = klass
3248
3349
34- def get_default_engine (ext , mode = "reader" ):
50+ def get_default_engine (ext : str , mode : Literal [ "reader" , "writer" ] = "reader" ) -> str :
3551 """
3652 Return the default reader/writer for the given extension.
3753
@@ -73,7 +89,7 @@ def get_default_engine(ext, mode="reader"):
7389 return _default_readers [ext ]
7490
7591
76- def get_writer (engine_name ) :
92+ def get_writer (engine_name : str ) -> ExcelWriter_t :
7793 try :
7894 return _writers [engine_name ]
7995 except KeyError as err :
@@ -145,7 +161,29 @@ def _range2cols(areas: str) -> list[int]:
145161 return cols
146162
147163
148- def maybe_convert_usecols (usecols ):
164+ @overload
165+ def maybe_convert_usecols (usecols : str | list [int ]) -> list [int ]:
166+ ...
167+
168+
169+ @overload
170+ def maybe_convert_usecols (usecols : list [str ]) -> list [str ]:
171+ ...
172+
173+
174+ @overload
175+ def maybe_convert_usecols (usecols : usecols_func ) -> usecols_func :
176+ ...
177+
178+
179+ @overload
180+ def maybe_convert_usecols (usecols : None ) -> None :
181+ ...
182+
183+
184+ def maybe_convert_usecols (
185+ usecols : str | list [int ] | list [str ] | usecols_func | None ,
186+ ) -> None | list [int ] | list [str ] | usecols_func :
149187 """
150188 Convert `usecols` into a compatible format for parsing in `parsers.py`.
151189
@@ -174,7 +212,17 @@ def maybe_convert_usecols(usecols):
174212 return usecols
175213
176214
177- def validate_freeze_panes (freeze_panes ):
215+ @overload
216+ def validate_freeze_panes (freeze_panes : tuple [int , int ]) -> Literal [True ]:
217+ ...
218+
219+
220+ @overload
221+ def validate_freeze_panes (freeze_panes : None ) -> Literal [False ]:
222+ ...
223+
224+
225+ def validate_freeze_panes (freeze_panes : tuple [int , int ] | None ) -> bool :
178226 if freeze_panes is not None :
179227 if len (freeze_panes ) == 2 and all (
180228 isinstance (item , int ) for item in freeze_panes
@@ -191,7 +239,9 @@ def validate_freeze_panes(freeze_panes):
191239 return False
192240
193241
194- def fill_mi_header (row , control_row ):
242+ def fill_mi_header (
243+ row : list [Hashable ], control_row : list [bool ]
244+ ) -> tuple [list [Hashable ], list [bool ]]:
195245 """
196246 Forward fill blank entries in row but only inside the same parent index.
197247
@@ -224,7 +274,9 @@ def fill_mi_header(row, control_row):
224274 return row , control_row
225275
226276
227- def pop_header_name (row , index_col ):
277+ def pop_header_name (
278+ row : list [Hashable ], index_col : int | Sequence [int ]
279+ ) -> tuple [Hashable | None , list [Hashable ]]:
228280 """
229281 Pop the header name for MultiIndex parsing.
230282
@@ -243,7 +295,12 @@ def pop_header_name(row, index_col):
243295 The original data row with the header name removed.
244296 """
245297 # Pop out header name and fill w/blank.
246- i = index_col if not is_list_like (index_col ) else max (index_col )
298+ if is_list_like (index_col ):
299+ assert isinstance (index_col , Iterable )
300+ i = max (index_col )
301+ else :
302+ assert not isinstance (index_col , Iterable )
303+ i = index_col
247304
248305 header_name = row [i ]
249306 header_name = None if header_name == "" else header_name
0 commit comments