Skip to content

Commit 15b635e

Browse files
committed
typeops: extend make_simplified_union fast path to enums
In PR python#9192 a fast path was created to address the slowness reported in issue python#9169 wherein large Union or literal types would dramatically slow down typechecking. It is desirable to extend this fast path to cover Enum types, as these can also leverage the O(n) set-based fast path instead of the O(n**2) fallback. This is seen to bring down the typechecking of a single fairly simple chain of `if` statements operating on a large enum (~3k members) from ~40min to 12s in real-world code! Note that the timing is taken from a pure-python run of mypy, as opposed to a compiled version.
1 parent 13ae58f commit 15b635e

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

Diff for: mypy/typeops.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar
8+
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Tuple
99
from typing_extensions import Type as TypingType
1010
import sys
1111

@@ -346,16 +346,17 @@ def make_simplified_union(items: Sequence[Type],
346346
removed = set() # type: Set[int]
347347

348348
# Avoid slow nested for loop for Union of Literal of strings (issue #9169)
349-
if all((isinstance(item, LiteralType) and
350-
item.fallback.type.fullname == 'builtins.str')
351-
for item in items):
352-
seen = set() # type: Set[str]
349+
if all((isinstance(item, LiteralType) and (
350+
item.fallback.type.is_enum or item.fallback.type.fullname == 'builtins.str'
351+
)) for item in items):
352+
seen = set() # type: Set[Tuple[str, str]]
353353
for index, item in enumerate(items):
354354
assert isinstance(item, LiteralType)
355355
assert isinstance(item.value, str)
356-
if item.value in seen:
356+
k = (item.value, item.fallback.type.fullname)
357+
if k in seen:
357358
removed.add(index)
358-
seen.add(item.value)
359+
seen.add(k)
359360

360361
else:
361362
for i, ti in enumerate(items):

0 commit comments

Comments
 (0)