From 9101707bd0c96624d09cb31fe573d7e25c89a35c Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Tue, 24 Dec 2019 20:29:32 -0800 Subject: [PATCH] Make reachability code understand chained comparisons (v2) (#8148) This pull request is v2 (well, more like v10...) of my attempts to make our reachability code better understand chained comparisons. Unlike https://github.com/python/mypy/pull/7169, this diff focuses exclusively on adding support for chained operation comparisons and deliberately does not attempt to change any of the semantics of how identity and equality operations are performed. Specifically, mypy currently only examines the first two operands within a comparison expression when refining types. That means the following expressions all do not behave as expected: ```python x: MyEnum y: MyEnum if x is y is MyEnum.A: # x and y are not narrowed at all if x is MyEnum.A is y: # Only x is narrowed to Literal[MyEnum.A] ``` This pull request fixes this so we correctly infer the literal type for x and y in both conditionals. Some additional notes: 1. While analyzing our codebase, I found that while comparison expressions involving two or more `is` or `==` operators were somewhat common, there were almost no comparisons involving chains of `!=` or `is not` operators, and no comparisons involving "disjoint chains" -- e.g. expressions like `a == b < c == b` where there are multiple "disjoint" chains of equality comparisons. So, this diff is primarily designed to handle the case where a comparison expression has just one chain of `is` or `==`. For all other cases, I fall back to the more naive strategy of evaluating each comparison individually and and-ing the inferred types together without attempting to propagate any info. 2. I tested this code against one of our internal codebases. This ended up making mypy produce 3 or 4 new errors, but they all seemed legitimate, as far as I can tell. 3. I plan on submitting a follow-up diff that takes advantage of the work done in this diff to complete support for tagged unions using any Literal key, as previously promised. (I tried adding support for tagged unions in this diff, but attempting to simultaneously add support for chained comparisons while overhauling the semantics of `==` proved to be a little too overwhelming for me. So, baby steps.) --- mypy/checker.py | 603 ++++++++++++++++++++++++++--- mypy/nodes.py | 7 + mypy/test/testinfer.py | 239 +++++++++++- test-data/unit/check-enum.test | 150 +++++++ test-data/unit/check-optional.test | 22 ++ 5 files changed, 956 insertions(+), 65 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 8528bf35248d..ae829d1157c1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5,8 +5,8 @@ from contextlib import contextmanager from typing import ( - Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Sequence, - Mapping, + Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Iterable, + Sequence, Mapping, Generic, AbstractSet ) from typing_extensions import Final @@ -27,7 +27,7 @@ is_final_node, ARG_NAMED) from mypy import nodes -from mypy.literals import literal, literal_hash +from mypy.literals import literal, literal_hash, Key from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any from mypy.types import ( Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, @@ -3842,67 +3842,101 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM vartype = type_map[expr] return self.conditional_callable_type_map(expr, vartype) elif isinstance(node, ComparisonExpr): - operand_types = [coerce_to_literal(type_map[expr]) - for expr in node.operands if expr in type_map] - - is_not = node.operators == ['is not'] - if (is_not or node.operators == ['is']) and len(operand_types) == len(node.operands): - if_vars = {} # type: TypeMap - else_vars = {} # type: TypeMap - - for i, expr in enumerate(node.operands): - var_type = operand_types[i] - other_type = operand_types[1 - i] - - if literal(expr) == LITERAL_TYPE and is_singleton_type(other_type): - # This should only be true at most once: there should be - # exactly two elements in node.operands and if the 'other type' is - # a singleton type, it by definition does not need to be narrowed: - # it already has the most precise type possible so does not need to - # be narrowed/included in the output map. - # - # TODO: Generalize this to handle the case where 'other_type' is - # a union of singleton types. - - if isinstance(other_type, LiteralType) and other_type.is_enum_literal(): - fallback_name = other_type.fallback.type.fullname - var_type = try_expanding_enum_to_union(var_type, fallback_name) - - target_type = [TypeRange(other_type, is_upper_bound=False)] - if_vars, else_vars = conditional_type_map(expr, var_type, target_type) - break + # Step 1: Obtain the types of each operand and whether or not we can + # narrow their types. (For example, we shouldn't try narrowing the + # types of literal string or enum expressions). + + operands = node.operands + operand_types = [] + narrowable_operand_index_to_hash = {} + for i, expr in enumerate(operands): + if expr not in type_map: + return {}, {} + expr_type = type_map[expr] + operand_types.append(expr_type) + + if (literal(expr) == LITERAL_TYPE + and not is_literal_none(expr) + and not is_literal_enum(type_map, expr)): + h = literal_hash(expr) + if h is not None: + narrowable_operand_index_to_hash[i] = h + + # Step 2: Group operands chained by either the 'is' or '==' operands + # together. For all other operands, we keep them in groups of size 2. + # So the expression: + # + # x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 + # + # ...is converted into the simplified operator list: + # + # [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), + # ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] + # + # We group identity/equality expressions so we can propagate information + # we discover about one operand across the entire chain. We don't bother + # handling 'is not' and '!=' chains in a special way: those are very rare + # in practice. + + simplified_operator_list = group_comparison_operands( + node.pairwise(), + narrowable_operand_index_to_hash, + {'==', 'is'}, + ) + + # Step 3: Analyze each group and infer more precise type maps for each + # assignable operand, if possible. We combine these type maps together + # in the final step. + + partial_type_maps = [] + for operator, expr_indices in simplified_operator_list: + if operator in {'is', 'is not'}: + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + elif operator in {'==', '!='}: + if_map, else_map = self.refine_equality_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + ) + elif operator in {'in', 'not in'}: + assert len(expr_indices) == 2 + left_index, right_index = expr_indices + if left_index not in narrowable_operand_index_to_hash: + continue + + item_type = operand_types[left_index] + collection_type = operand_types[right_index] + + # We only try and narrow away 'None' for now + if not is_optional(item_type): + pass - if is_not: - if_vars, else_vars = else_vars, if_vars - return if_vars, else_vars - # Check for `x == y` where x is of type Optional[T] and y is of type T - # or a type that overlaps with T (or vice versa). - elif node.operators == ['==']: - first_type = type_map[node.operands[0]] - second_type = type_map[node.operands[1]] - if is_optional(first_type) != is_optional(second_type): - if is_optional(first_type): - optional_type, comp_type = first_type, second_type - optional_expr = node.operands[0] + collection_item_type = get_proper_type(builtin_item_type(collection_type)) + if collection_item_type is None or is_optional(collection_item_type): + continue + if (isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == 'builtins.object'): + continue + if is_overlapping_erased_types(item_type, collection_item_type): + if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} else: - optional_type, comp_type = second_type, first_type - optional_expr = node.operands[1] - if is_overlapping_erased_types(optional_type, comp_type): - return {optional_expr: remove_optional(optional_type)}, {} - elif node.operators in [['in'], ['not in']]: - expr = node.operands[0] - left_type = type_map[expr] - right_type = get_proper_type(builtin_item_type(type_map[node.operands[1]])) - right_ok = right_type and (not is_optional(right_type) and - (not isinstance(right_type, Instance) or - right_type.type.fullname != 'builtins.object')) - if (right_type and right_ok and is_optional(left_type) and - literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and - is_overlapping_erased_types(left_type, right_type)): - if node.operators == ['in']: - return {expr: remove_optional(left_type)}, {} - if node.operators == ['not in']: - return {}, {expr: remove_optional(left_type)} + continue + else: + if_map = {} + else_map = {} + + if operator in {'is not', '!=', 'not in'}: + if_map, else_map = else_map, if_map + + partial_type_maps.append((if_map, else_map)) + + return reduce_partial_conditional_maps(partial_type_maps) elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -4107,6 +4141,143 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: return output + def refine_identity_comparison_expression(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining expressions used in an identity comparison. + + The 'operands' and 'operand_types' lists should be the full list of operands used + in the overall comparison expression. The 'chain_indices' list is the list of indices + actually used within this identity comparison chain. + + So if we have the expression: + + a <= b is c is d <= e + + ...then 'operands' and 'operand_types' would be lists of length 5 and 'chain_indices' + would be the list [1, 2, 3]. + + The 'narrowable_operand_indices' parameter is the set of all indices we are allowed + to refine the types of: that is, all operands that will potentially be a part of + the output TypeMaps. + """ + singleton = None # type: Optional[ProperType] + possible_singleton_indices = [] + for i in chain_indices: + coerced_type = coerce_to_literal(operand_types[i]) + if not is_singleton_type(coerced_type): + continue + if singleton and not is_same_type(singleton, coerced_type): + # We have multiple disjoint singleton types. So the 'if' branch + # must be unreachable. + return None, {} + singleton = coerced_type + possible_singleton_indices.append(i) + + # There's nothing we can currently infer if none of the operands are singleton types, + # so we end early and infer nothing. + if singleton is None: + return {}, {} + + # If possible, use an unassignable expression as the singleton. + # We skip refining the type of the singleton below, so ideally we'd + # want to pick an expression we were going to skip anyways. + singleton_index = -1 + for i in possible_singleton_indices: + if i not in narrowable_operand_indices: + singleton_index = i + + # But if none of the possible singletons are unassignable ones, we give up + # and arbitrarily pick the last item, mostly because other parts of the + # type narrowing logic bias towards picking the rightmost item and it'd be + # nice to stay consistent. + # + # That said, it shouldn't matter which index we pick. For example, suppose we + # have this if statement, where 'x' and 'y' both have singleton types: + # + # if x is y: + # reveal_type(x) + # reveal_type(y) + # else: + # reveal_type(x) + # reveal_type(y) + # + # At this point, 'x' and 'y' *must* have the same singleton type: we would have + # ended early in the first for-loop in this function if they weren't. + # + # So, we should always get the same result in the 'if' case no matter which + # index we pick. And while we do end up getting different results in the 'else' + # case depending on the index (e.g. if we pick 'y', then its type stays the same + # while 'x' is narrowed to ''), this distinction is also moot: mypy + # currently will just mark the whole branch as unreachable if either operand is + # narrowed to . + if singleton_index == -1: + singleton_index = possible_singleton_indices[-1] + + enum_name = None + if isinstance(singleton, LiteralType) and singleton.is_enum_literal(): + enum_name = singleton.fallback.type.fullname + + target_type = [TypeRange(singleton, is_upper_bound=False)] + + partial_type_maps = [] + for i in chain_indices: + # If we try refining a singleton against itself, conditional_type_map + # will end up assuming that the 'else' branch is unreachable. This is + # typically not what we want: generally the user will intend for the + # singleton type to be some fixed 'sentinel' value and will want to refine + # the other exprs against this one instead. + if i == singleton_index: + continue + + # Naturally, we can't refine operands which are not permitted to be refined. + if i not in narrowable_operand_indices: + continue + + expr = operands[i] + expr_type = coerce_to_literal(operand_types[i]) + + if enum_name is not None: + expr_type = try_expanding_enum_to_union(expr_type, enum_name) + partial_type_maps.append(conditional_type_map(expr, expr_type, target_type)) + + return reduce_partial_conditional_maps(partial_type_maps) + + def refine_equality_comparison_expression(self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: + """Produces conditional type maps refining expressions used in an equality comparison. + + For more details, see the docstring of 'refine_equality_comparison' up above. + The only difference is that this function is for refining equality operations + (e.g. 'a == b == c') instead of identity ('a is b is c'). + """ + non_optional_types = [] + for i in chain_indices: + typ = operand_types[i] + if not is_optional(typ): + non_optional_types.append(typ) + + # Make sure we have a mixture of optional and non-optional types. + if len(non_optional_types) == 0 or len(non_optional_types) == len(chain_indices): + return {}, {} + + if_map = {} + for i in narrowable_operand_indices: + expr_type = operand_types[i] + if not is_optional(expr_type): + continue + if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types): + if_map[operands[i]] = remove_optional(expr_type) + + return if_map, {} + # # Helpers # @@ -4541,16 +4712,55 @@ def gen_unique_name(base: str, table: SymbolTable) -> str: def is_true_literal(n: Expression) -> bool: + """Returns true if this expression is the 'True' literal/keyword.""" return (refers_to_fullname(n, 'builtins.True') or isinstance(n, IntExpr) and n.value == 1) def is_false_literal(n: Expression) -> bool: + """Returns true if this expression is the 'False' literal/keyword.""" return (refers_to_fullname(n, 'builtins.False') or isinstance(n, IntExpr) and n.value == 0) +def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool: + """Returns true if this expression (with the given type context) is an Enum literal. + + For example, if we had an enum: + + class Foo(Enum): + A = 1 + B = 2 + + ...and if the expression 'Foo' referred to that enum within the current type context, + then the expression 'Foo.A' would be a a literal enum. However, if we did 'a = Foo.A', + then the variable 'a' would *not* be a literal enum. + + We occasionally special-case expressions like 'Foo.A' and treat them as a single primitive + unit for the same reasons we sometimes treat 'True', 'False', or 'None' as a single + primitive unit. + """ + if not isinstance(n, MemberExpr) or not isinstance(n.expr, NameExpr): + return False + + parent_type = type_map.get(n.expr) + member_type = type_map.get(n) + if member_type is None or parent_type is None: + return False + + parent_type = get_proper_type(parent_type) + member_type = coerce_to_literal(member_type) + if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType): + return False + + if not parent_type.is_type_obj(): + return False + + return member_type.is_enum_literal() and member_type.fallback.type == parent_type.type_object() + + def is_literal_none(n: Expression) -> bool: + """Returns true if this expression is the 'None' literal/keyword.""" return isinstance(n, NameExpr) and n.fullname == 'builtins.None' @@ -4641,6 +4851,76 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result +def or_partial_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: + """Calculate what information we can learn from the truth of (e1 or e2) + in terms of the information that we can learn from the truth of e1 and + the truth of e2. + + Unlike 'or_conditional_maps', we include an expression in the output even + if it exists in only one map: we're assuming both maps are "partial" and + contain information about only some expressions, and so we "or" together + expressions both maps have information on. + """ + + if m1 is None: + return m2 + if m2 is None: + return m1 + # The logic here is a blend between 'and_conditional_maps' + # and 'or_conditional_maps'. We use the high-level logic from the + # former to ensure all expressions make it in the output map, + # but resolve cases where both maps contain info on the same + # expr using the unioning strategy from the latter. + result = m2.copy() + m2_keys = {literal_hash(n2): n2 for n2 in m2} + for n1 in m1: + n2 = m2_keys.get(literal_hash(n1)) + if n2 is None: + result[n1] = m1[n1] + else: + result[n2] = make_simplified_union([m1[n1], result[n2]]) + + return result + + +def reduce_partial_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], + ) -> Tuple[TypeMap, TypeMap]: + """Reduces a list containing pairs of *partial* if/else TypeMaps into a single pair. + + That is, if a expression exists in only one map, we always include it in the output. + We only "and"/"or" together expressions that appear in multiple if/else maps. + + So for example, if we had the input: + + [ + ({x: TypeIfX, shared: TypeIfShared1}, {x: TypeElseX, shared: TypeElseShared1}), + ({y: TypeIfY, shared: TypeIfShared2}, {y: TypeElseY, shared: TypeElseShared2}), + ] + + ...we'd return the output: + + ( + {x: TypeIfX, y: TypeIfY, shared: PseudoIntersection[TypeIfShared1, TypeIfShared2]}, + {x: TypeElseX, y: TypeElseY, shared: Union[TypeElseShared1, TypeElseShared2]}, + ) + + ...where "PseudoIntersection[X, Y] == Y" because mypy actually doesn't understand intersections + yet, so we settle for just arbitrarily picking the right expr's type. + """ + if len(type_maps) == 0: + return {}, {} + elif len(type_maps) == 1: + return type_maps[0] + else: + final_if_map, final_else_map = type_maps[0] + for if_map, else_map in type_maps[1:]: + # 'and_conditional_maps' does the same thing for both global and partial type maps, + # which is why we don't need to have an 'and_partial_conditional_maps' function. + final_if_map = and_conditional_maps(final_if_map, if_map) + final_else_map = or_partial_conditional_maps(final_else_map, else_map) + return final_if_map, final_else_map + + def convert_to_typetype(type_map: TypeMap) -> TypeMap: converted_type_map = {} # type: Dict[Expression, Type] if type_map is None: @@ -5007,6 +5287,205 @@ def nothing() -> Iterator[None]: yield +TKey = TypeVar('TKey') +TValue = TypeVar('TValue') + + +class DisjointDict(Generic[TKey, TValue]): + """An variation of the union-find algorithm/data structure where instead of keeping + track of just disjoint sets, we keep track of disjoint dicts -- keep track of multiple + Set[Key] -> Set[Value] mappings, where each mapping's keys are guaranteed to be disjoint. + + This data structure is currently used exclusively by 'group_comparison_operands' below + to merge chains of '==' and 'is' comparisons when two or more chains use the same expression + in best-case O(n), where n is the number of operands. + + Specifically, the `add_mapping()` function and `items()` functions will take on average + O(k + v) and O(n) respectively, where k and v are the number of keys and values we're adding + for a given chain. Note that k <= n and v <= n. + + We hit these average/best-case scenarios for most user code: e.g. when the user has just + a single chain like 'a == b == c == d == ...' or multiple disjoint chains like + 'a==b < c==d < e==f < ...'. (Note that a naive iterative merging would be O(n^2) for + the latter case). + + In comparison, this data structure will make 'group_comparison_operands' have a worst-case + runtime of O(n*log(n)): 'add_mapping()' and 'items()' are worst-case O(k*log(n) + v) and + O(k*log(n)) respectively. This happens only in the rare case where the user keeps repeatedly + making disjoint mappings before merging them in a way that persistently dodges the path + compression optimization in '_lookup_root_id', which would end up constructing a single + tree of height log_2(n). This makes root lookups no longer amoritized constant time when we + finally call 'items()'. + """ + def __init__(self) -> None: + # Each key maps to a unique ID + self._key_to_id = {} # type: Dict[TKey, int] + + # Each id points to the parent id, forming a forest of upwards-pointing trees. If the + # current id already is the root, it points to itself. We gradually flatten these trees + # as we perform root lookups: eventually all nodes point directly to its root. + self._id_to_parent_id = {} # type: Dict[int, int] + + # Each root id in turn maps to the set of values. + self._root_id_to_values = {} # type: Dict[int, Set[TValue]] + + def add_mapping(self, keys: Set[TKey], values: Set[TValue]) -> None: + """Adds a 'Set[TKey] -> Set[TValue]' mapping. If there already exists a mapping + containing one or more of the given keys, we merge the input mapping with the old one. + + Note that the given set of keys must be non-empty -- otherwise, nothing happens. + """ + if len(keys) == 0: + return + + subtree_roots = [self._lookup_or_make_root_id(key) for key in keys] + new_root = subtree_roots[0] + + root_values = self._root_id_to_values[new_root] + root_values.update(values) + for subtree_root in subtree_roots[1:]: + if subtree_root == new_root or subtree_root not in self._root_id_to_values: + continue + self._id_to_parent_id[subtree_root] = new_root + root_values.update(self._root_id_to_values.pop(subtree_root)) + + def items(self) -> List[Tuple[Set[TKey], Set[TValue]]]: + """Returns all disjoint mappings in key-value pairs.""" + root_id_to_keys = {} # type: Dict[int, Set[TKey]] + for key in self._key_to_id: + root_id = self._lookup_root_id(key) + if root_id not in root_id_to_keys: + root_id_to_keys[root_id] = set() + root_id_to_keys[root_id].add(key) + + output = [] + for root_id, keys in root_id_to_keys.items(): + output.append((keys, self._root_id_to_values[root_id])) + + return output + + def _lookup_or_make_root_id(self, key: TKey) -> int: + if key in self._key_to_id: + return self._lookup_root_id(key) + else: + new_id = len(self._key_to_id) + self._key_to_id[key] = new_id + self._id_to_parent_id[new_id] = new_id + self._root_id_to_values[new_id] = set() + return new_id + + def _lookup_root_id(self, key: TKey) -> int: + i = self._key_to_id[key] + while i != self._id_to_parent_id[i]: + # Optimization: make keys directly point to their grandparents to speed up + # future traversals. This prevents degenerate trees of height n from forming. + new_parent = self._id_to_parent_id[self._id_to_parent_id[i]] + self._id_to_parent_id[i] = new_parent + i = new_parent + return i + + +def group_comparison_operands(pairwise_comparisons: Iterable[Tuple[str, Expression, Expression]], + operand_to_literal_hash: Mapping[int, Key], + operators_to_group: Set[str], + ) -> List[Tuple[str, List[int]]]: + """Group a series of comparison operands together chained by any operand + in the 'operators_to_group' set. All other pairwise operands are kept in + groups of size 2. + + For example, suppose we have the input comparison expression: + + x0 == x1 == x2 < x3 < x4 is x5 is x6 is not x7 is not x8 + + If we get these expressions in a pairwise way (e.g. by calling ComparisionExpr's + 'pairwise()' method), we get the following as input: + + [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('<', x3, x4), + ('is', x4, x5), ('is', x5, x6), ('is not', x6, x7), ('is not', x7, x8)] + + If `operators_to_group` is the set {'==', 'is'}, this function will produce + the following "simplified operator list": + + [("==", [0, 1, 2]), ("<", [2, 3]), ("<", [3, 4]), + ("is", [4, 5, 6]), ("is not", [6, 7]), ("is not", [7, 8])] + + Note that (a) we yield *indices* to the operands rather then the operand + expressions themselves and that (b) operands used in a consecutive chain + of '==' or 'is' are grouped together. + + If two of these chains happen to contain operands with the same underlying + literal hash (e.g. are assignable and correspond to the same expression), + we combine those chains together. For example, if we had: + + same == x < y == same + + ...and if 'operand_to_literal_hash' contained the same values for the indices + 0 and 3, we'd produce the following output: + + [("==", [0, 1, 2, 3]), ("<", [1, 2])] + + But if the 'operand_to_literal_hash' did *not* contain an entry, we'd instead + default to returning: + + [("==", [0, 1]), ("<", [1, 2]), ("==", [2, 3])] + + This function is currently only used to assist with type-narrowing refinements + and is extracted out to a helper function so we can unit test it. + """ + groups = { + op: DisjointDict() for op in operators_to_group + } # type: Dict[str, DisjointDict[Key, int]] + + simplified_operator_list = [] # type: List[Tuple[str, List[int]]] + last_operator = None # type: Optional[str] + current_indices = set() # type: Set[int] + current_hashes = set() # type: Set[Key] + for i, (operator, left_expr, right_expr) in enumerate(pairwise_comparisons): + if last_operator is None: + last_operator = operator + + if current_indices and (operator != last_operator or operator not in operators_to_group): + # If some of the operands in the chain are assignable, defer adding it: we might + # end up needing to merge it with other chains that appear later. + if len(current_hashes) == 0: + simplified_operator_list.append((last_operator, sorted(current_indices))) + else: + groups[last_operator].add_mapping(current_hashes, current_indices) + last_operator = operator + current_indices = set() + current_hashes = set() + + # Note: 'i' corresponds to the left operand index, so 'i + 1' is the + # right operand. + current_indices.add(i) + current_indices.add(i + 1) + + # We only ever want to combine operands/combine chains for these operators + if operator in operators_to_group: + left_hash = operand_to_literal_hash.get(i) + if left_hash is not None: + current_hashes.add(left_hash) + right_hash = operand_to_literal_hash.get(i + 1) + if right_hash is not None: + current_hashes.add(right_hash) + + if last_operator is not None: + if len(current_hashes) == 0: + simplified_operator_list.append((last_operator, sorted(current_indices))) + else: + groups[last_operator].add_mapping(current_hashes, current_indices) + + # Now that we know which chains happen to contain the same underlying expressions + # and can be merged together, add in this info back to the output. + for operator, disjoint_dict in groups.items(): + for keys, indices in disjoint_dict.items(): + simplified_operator_list.append((operator, sorted(indices))) + + # For stability, reorder list by the first operand index to appear + simplified_operator_list.sort(key=lambda item: item[1][0]) + return simplified_operator_list + + def is_typed_callable(c: Optional[Type]) -> bool: c = get_proper_type(c) if not c or not isinstance(c, CallableType): diff --git a/mypy/nodes.py b/mypy/nodes.py index 4ee3948fedd3..792a89a5fea4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1750,6 +1750,13 @@ def __init__(self, operators: List[str], operands: List[Expression]) -> None: self.operands = operands self.method_types = [] + def pairwise(self) -> Iterator[Tuple[str, Expression, Expression]]: + """If this comparison expr is "a < b is c == d", yields the sequence + ("<", a, b), ("is", b, c), ("==", c, d) + """ + for i, operator in enumerate(self.operators): + yield operator, self.operands[i], self.operands[i + 1] + def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_comparison_expr(self) diff --git a/mypy/test/testinfer.py b/mypy/test/testinfer.py index 2e26e99453b8..e70d74530a99 100644 --- a/mypy/test/testinfer.py +++ b/mypy/test/testinfer.py @@ -1,16 +1,18 @@ """Test cases for type inference helper functions.""" -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Set from mypy.test.helpers import Suite, assert_equal from mypy.argmap import map_actuals_to_formals -from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED +from mypy.checker import group_comparison_operands, DisjointDict +from mypy.literals import Key +from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, NameExpr from mypy.types import AnyType, TupleType, Type, TypeOfAny from mypy.test.typefixture import TypeFixture class MapActualsToFormalsSuite(Suite): - """Test cases for checkexpr.map_actuals_to_formals.""" + """Test cases for argmap.map_actuals_to_formals.""" def test_basic(self) -> None: self.assert_map([], [], []) @@ -223,3 +225,234 @@ def expand_callee_kinds(kinds_and_names: List[Union[int, Tuple[int, str]]] kinds.append(v) names.append(None) return kinds, names + + +class OperandDisjointDictSuite(Suite): + """Test cases for checker.DisjointDict, which is used for type inference with operands.""" + def new(self) -> DisjointDict[int, str]: + return DisjointDict() + + def test_independent_maps(self) -> None: + d = self.new() + d.add_mapping({0, 1}, {"group1"}) + d.add_mapping({2, 3, 4}, {"group2"}) + d.add_mapping({5, 6, 7}, {"group3"}) + + self.assertEqual(d.items(), [ + ({0, 1}, {"group1"}), + ({2, 3, 4}, {"group2"}), + ({5, 6, 7}, {"group3"}), + ]) + + def test_partial_merging(self) -> None: + d = self.new() + d.add_mapping({0, 1}, {"group1"}) + d.add_mapping({1, 2}, {"group2"}) + d.add_mapping({3, 4}, {"group3"}) + d.add_mapping({5, 0}, {"group4"}) + d.add_mapping({5, 6}, {"group5"}) + d.add_mapping({4, 7}, {"group6"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}), + ({3, 4, 7}, {"group3", "group6"}), + ]) + + def test_full_merging(self) -> None: + d = self.new() + d.add_mapping({0, 1, 2}, {"a"}) + d.add_mapping({3, 4, 2}, {"b"}) + d.add_mapping({10, 11, 12}, {"c"}) + d.add_mapping({13, 14, 15}, {"d"}) + d.add_mapping({14, 10, 16}, {"e"}) + d.add_mapping({0, 10}, {"f"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"}), + ]) + + def test_merge_with_multiple_overlaps(self) -> None: + d = self.new() + d.add_mapping({0, 1, 2}, {"a"}) + d.add_mapping({3, 4, 5}, {"b"}) + d.add_mapping({1, 2, 4, 5}, {"c"}) + d.add_mapping({6, 1, 2, 4, 5}, {"d"}) + d.add_mapping({6, 1, 2, 4, 5}, {"e"}) + + self.assertEqual(d.items(), [ + ({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"}), + ]) + + +class OperandComparisonGroupingSuite(Suite): + """Test cases for checker.group_comparison_operands.""" + def literal_keymap(self, assignable_operands: Dict[int, NameExpr]) -> Dict[int, Key]: + output = {} # type: Dict[int, Key] + for index, expr in assignable_operands.items(): + output[index] = ('FakeExpr', expr.name) + return output + + def test_basic_cases(self) -> None: + # Note: the grouping function doesn't actually inspect the input exprs, so we + # just default to using NameExprs for simplicity. + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + + basic_input = [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4)] + + none_assignable = self.literal_keymap({}) + all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4}) + + for assignable in [none_assignable, all_assignable]: + self.assertEqual( + group_comparison_operands(basic_input, assignable, set()), + [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'=='}), + [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'<'}), + [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + self.assertEqual( + group_comparison_operands(basic_input, assignable, {'==', '<'}), + [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + ) + + def test_multiple_groups(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + x5 = NameExpr('x5') + + self.assertEqual( + group_comparison_operands( + [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('==', [0, 1, 2]), ('is', [2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('==', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('==', [0, 1, 2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('is', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('is', [0, 1]), ('==', [1, 2, 3, 4])], + ) + self.assertEqual( + group_comparison_operands( + [('is', x0, x1), ('is', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x5)], + self.literal_keymap({}), + {'==', 'is'}, + ), + [('is', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])], + ) + + def test_multiple_groups_coalescing(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + x4 = NameExpr('x4') + + nothing_combined = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])] + everything_combined = [('==', [0, 1, 2, 3, 4, 5]), ('<', [2, 3])] + + # Note: We do 'x4 == x0' at the very end! + two_groups = [ + ('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x0), + ] + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0}), + {'=='}, + ), + everything_combined, + "All vars are assignable, everything is combined" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({1: x1, 2: x2, 3: x3, 4: x4}), + {'=='}, + ), + nothing_combined, + "x0 is unassignable, so no combining" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 1: x1, 3: x3, 5: x0}), + {'=='}, + ), + everything_combined, + "Some vars are unassignable but x0 is, so we combine" + ) + self.assertEqual( + group_comparison_operands( + two_groups, + self.literal_keymap({0: x0, 5: x0}), + {'=='}, + ), + everything_combined, + "All vars are unassignable but x0 is, so we combine" + ) + + def test_multiple_groups_different_operators(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + x2 = NameExpr('x2') + x3 = NameExpr('x3') + + groups = [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x0)] + keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0}) + self.assertEqual( + group_comparison_operands(groups, keymap, {'==', 'is'}), + [('==', [0, 1, 2]), ('is', [2, 3, 4])], + "Different operators can never be combined" + ) + + def test_single_pair(self) -> None: + x0 = NameExpr('x0') + x1 = NameExpr('x1') + + single_comparison = [('==', x0, x1)] + expected_output = [('==', [0, 1])] + + assignable_combinations = [ + {}, {0: x0}, {1: x1}, {0: x0, 1: x1}, + ] # type: List[Dict[int, NameExpr]] + to_group_by = [set(), {'=='}, {'is'}] # type: List[Set[str]] + + for combo in assignable_combinations: + for operators in to_group_by: + keymap = self.literal_keymap(combo) + self.assertEqual( + group_comparison_operands(single_comparison, keymap, operators), + expected_output, + ) + + def test_empty_pair_list(self) -> None: + # This case should never occur in practice -- ComparisionExprs + # always contain at least one comparision. But in case it does... + + self.assertEqual(group_comparison_operands([], {}, set()), []) + self.assertEqual(group_comparison_operands([], {}, {'=='}), []) diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 241cd1ca049c..9d027f47192f 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -967,3 +967,153 @@ class A: self.b = Enum("x", [("foo", "bar")]) # E: Enum type as attribute is not supported reveal_type(A().b) # N: Revealed type is 'Any' + +[case testEnumReachabilityWithChaining] +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +x: Foo +y: Foo + +if x is y is Foo.A: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +if x is Foo.A is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +if Foo.A is x is y: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingDisjoint] +# flags: --warn-unreachable +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + + # Used to divide up a chained comparison into multiple identity groups + def __lt__(self, other: object) -> bool: return True + +x: Foo +y: Foo + +# No conflict +if x is Foo.A < y is Foo.B: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.B]' +else: + reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is '__main__.Foo' + +# The standard output when we end up inferring two disjoint facts about the same expr +if x is Foo.A and x is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +# ..and we get the same result if we have two disjoint groups within the same comp expr +if x is Foo.A < x is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingDirectConflict] +# flags: --warn-unreachable +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +x: Foo +if x is Foo.A is Foo.B: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +literal_a: Literal[Foo.A] +literal_b: Literal[Foo.B] +if x is literal_a is literal_b: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +final_a: Final = Foo.A +final_b: Final = Foo.B +if x is final_a is final_b: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(x) # N: Revealed type is '__main__.Foo' + +[builtins fixtures/primitives.pyi] + +[case testEnumReachabilityWithChainingBigDisjoints] +# flags: --warn-unreachable +from enum import Enum +from typing_extensions import Literal, Final + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + + def __lt__(self, other: object) -> bool: return True + +x0: Foo +x1: Foo +x2: Foo +x3: Foo +x4: Foo +x5: Foo + +if x0 is x1 is Foo.A is x2 < x3 is Foo.B is x4 is x5: + reveal_type(x0) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' + reveal_type(x2) # N: Revealed type is 'Literal[__main__.Foo.A]' + + reveal_type(x3) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x4) # N: Revealed type is 'Literal[__main__.Foo.B]' + reveal_type(x5) # N: Revealed type is 'Literal[__main__.Foo.B]' +else: + reveal_type(x0) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + reveal_type(x2) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + + reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + reveal_type(x4) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' + reveal_type(x5) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.C]]' +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 4b18cb59d1a7..9c40d550699e 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -528,6 +528,28 @@ else: reveal_type(x) # N: Revealed type is 'Union[builtins.str, builtins.int, None]' [builtins fixtures/ops.pyi] +[case testInferEqualsNotOptionalWithMultipleArgs] +from typing import Optional +x: Optional[int] +y: Optional[int] +if x == y == 1: + reveal_type(x) # N: Revealed type is 'builtins.int' + reveal_type(y) # N: Revealed type is 'builtins.int' +else: + reveal_type(x) # N: Revealed type is 'Union[builtins.int, None]' + reveal_type(y) # N: Revealed type is 'Union[builtins.int, None]' + +class A: pass +a: Optional[A] +b: Optional[A] +if a == b == object(): + reveal_type(a) # N: Revealed type is '__main__.A' + reveal_type(b) # N: Revealed type is '__main__.A' +else: + reveal_type(a) # N: Revealed type is 'Union[__main__.A, None]' + reveal_type(b) # N: Revealed type is 'Union[__main__.A, None]' +[builtins fixtures/ops.pyi] + [case testWarnNoReturnWorksWithStrictOptional] # flags: --warn-no-return def f() -> None: