Skip to content

Commit

Permalink
Merge pull request #2109 from pallets/more-typing
Browse files Browse the repository at this point in the history
enable more mypy checks
  • Loading branch information
davidism authored May 11, 2021
2 parents 30ce9ee + 2dfb6dc commit d6f89ee
Show file tree
Hide file tree
Showing 36 changed files with 533 additions and 361 deletions.
14 changes: 10 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,24 @@ python_version = 3.6
allow_redefinition = True
disallow_subclassing_any = True
# disallow_untyped_calls = True
# disallow_untyped_defs = True
# disallow_incomplete_defs = True
disallow_untyped_defs = True
disallow_incomplete_defs = True
no_implicit_optional = True
local_partial_types = True
# no_implicit_reexport = True
no_implicit_reexport = True
strict_equality = True
warn_redundant_casts = True
warn_unused_configs = True
warn_unused_ignores = True
# warn_return_any = True
warn_return_any = True
# warn_unreachable = True

[mypy-werkzeug]
no_implicit_reexport = False

[mypy-werkzeug.wrappers]
no_implicit_reexport = False

[mypy-colorama.*]
ignore_missing_imports = True

Expand Down
48 changes: 30 additions & 18 deletions src/werkzeug/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from weakref import WeakKeyDictionary

if t.TYPE_CHECKING:
from wsgiref.types import StartResponse
from wsgiref.types import WSGIApplication
from wsgiref.types import WSGIEnvironment
from .wrappers.request import Request # noqa: F401
Expand Down Expand Up @@ -48,10 +49,10 @@


class _Missing:
def __repr__(self):
def __repr__(self) -> str:
return "no value"

def __reduce__(self):
def __reduce__(self) -> str:
return "_missing"


Expand All @@ -68,7 +69,7 @@ def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]:
...


def _make_encode_wrapper(reference):
def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]:
"""Create a function that will be called with a string argument. If
the reference is bytes, values will be encoded to bytes.
"""
Expand Down Expand Up @@ -127,7 +128,12 @@ def _to_str(
...


def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=False):
def _to_str(
x: t.Optional[t.Any],
charset: t.Optional[str] = _default_encoding,
errors: str = "strict",
allow_none_charset: bool = False,
) -> t.Optional[t.Union[str, bytes]]:
if x is None or isinstance(x, str):
return x

Expand All @@ -138,7 +144,7 @@ def _to_str(x, charset=_default_encoding, errors="strict", allow_none_charset=Fa
if allow_none_charset:
return x

return x.decode(charset, errors)
return x.decode(charset, errors) # type: ignore


def _wsgi_decoding_dance(
Expand Down Expand Up @@ -186,7 +192,7 @@ def _has_level_handler(logger: logging.Logger) -> bool:
class _ColorStreamHandler(logging.StreamHandler):
"""On Windows, wrap stream with Colorama for ANSI style support."""

def __init__(self):
def __init__(self) -> None:
try:
import colorama
except ImportError:
Expand All @@ -197,7 +203,7 @@ def __init__(self):
super().__init__(stream)


def _log(type: str, message: str, *args, **kwargs) -> None:
def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None:
"""Log a message to the 'werkzeug' logger.
The logger is created the first time it is needed. If there is no
Expand All @@ -219,7 +225,7 @@ def _log(type: str, message: str, *args, **kwargs) -> None:
getattr(_logger, type)(message.rstrip(), *args, **kwargs)


def _parse_signature(func):
def _parse_signature(func): # type: ignore
"""Return a signature object for the function.
.. deprecated:: 2.0
Expand Down Expand Up @@ -251,7 +257,7 @@ def _parse_signature(func):
arguments.append(param)
arguments = tuple(arguments)

def parse(args, kwargs):
def parse(args, kwargs): # type: ignore
new_args = []
missing = []
extra = {}
Expand Down Expand Up @@ -306,7 +312,7 @@ def _dt_as_utc(dt: datetime) -> datetime:
...


def _dt_as_utc(dt):
def _dt_as_utc(dt: t.Optional[datetime]) -> t.Optional[datetime]:
if dt is None:
return dt

Expand Down Expand Up @@ -356,24 +362,26 @@ def __get__(
def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue:
...

def __get__(self, instance, owner):
def __get__(
self, instance: t.Optional[t.Any], owner: type
) -> t.Union[_TAccessorValue, "_DictAccessorProperty[_TAccessorValue]"]:
if instance is None:
return self

storage = self.lookup(instance)

if self.name not in storage:
return self.default
return self.default # type: ignore

value = storage[self.name]

if self.load_func is not None:
try:
return self.load_func(value)
except (ValueError, TypeError):
return self.default
return self.default # type: ignore

return value
return value # type: ignore

def __set__(self, instance: t.Any, value: _TAccessorValue) -> None:
if self.read_only:
Expand Down Expand Up @@ -513,7 +521,7 @@ def _make_cookie_domain(domain: str) -> bytes:
...


def _make_cookie_domain(domain):
def _make_cookie_domain(domain: t.Optional[str]) -> t.Optional[bytes]:
if domain is None:
return None
domain = _encode_idna(domain)
Expand All @@ -533,7 +541,7 @@ def _make_cookie_domain(domain):
def _easteregg(app: t.Optional["WSGIApplication"] = None) -> "WSGIApplication":
"""Like the name says. But who knows how it works?"""

def bzzzzzzz(gyver):
def bzzzzzzz(gyver: bytes) -> str:
import base64
import zlib

Expand Down Expand Up @@ -579,8 +587,12 @@ def bzzzzzzz(gyver):
]
)

def easteregged(environ, start_response):
def injecting_start_response(status, headers, exc_info=None):
def easteregged(
environ: "WSGIEnvironment", start_response: "StartResponse"
) -> t.Iterable[bytes]:
def injecting_start_response(
status: str, headers: t.List[t.Tuple[str, str]], exc_info: t.Any = None
) -> t.Callable[[bytes], t.Any]:
headers.append(("X-Powered-By", "Werkzeug"))
return start_response(status, headers, exc_info)

Expand Down
14 changes: 7 additions & 7 deletions src/werkzeug/_reloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]:

rv = set()

def _walk(node, path):
def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None:
for prefix, child in node.items():
_walk(child, path + (prefix,))

Expand Down Expand Up @@ -218,7 +218,7 @@ def __enter__(self) -> "ReloaderLoop":
self.run_step()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
"""Clean up any resources associated with the reloader."""
pass

Expand Down Expand Up @@ -284,15 +284,15 @@ def run_step(self) -> None:


class WatchdogReloaderLoop(ReloaderLoop):
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler

super().__init__(*args, **kwargs)
trigger_reload = self.trigger_reload

class EventHandler(PatternMatchingEventHandler): # type: ignore
def on_any_event(self, event):
def on_any_event(self, event): # type: ignore
trigger_reload(event.src_path)

reloader_name = Observer.__name__.lower()
Expand Down Expand Up @@ -331,7 +331,7 @@ def __enter__(self) -> ReloaderLoop:
self.observer.start()
return super().__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore
self.observer.stop()
self.observer.join()

Expand Down Expand Up @@ -379,7 +379,7 @@ def run_step(self) -> None:
reloader_loops["auto"] = reloader_loops["watchdog"]


def ensure_echo_on():
def ensure_echo_on() -> None:
"""Ensure that echo mode is enabled. Some tools such as PDB disable
it which causes usability issues after a reload."""
# tcgetattr will fail if stdin isn't a tty
Expand All @@ -404,7 +404,7 @@ def run_with_reloader(
exclude_patterns: t.Optional[t.Iterable[str]] = None,
interval: t.Union[int, float] = 1,
reloader_type: str = "auto",
):
) -> None:
"""Run the given function in an independent Python interpreter."""
import signal

Expand Down
Loading

0 comments on commit d6f89ee

Please sign in to comment.