diff --git a/mypy/checker.py b/mypy/checker.py index 96b55f321a73c..94f4d0dd5e440 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -4366,8 +4366,9 @@ def check_lvalue( types = [ self.check_lvalue(sub_expr)[0] or # This type will be used as a context for further inference of rvalue, - # we put Uninhabited if there is no information available from lvalue. - UninhabitedType() + # we put AnyType if there is no information available from lvalue. + AnyType(TypeOfAny.unannotated) + # UninhabitedType() fails testInferenceNestedTuplesFromGenericIterable for sub_expr in lvalue.items ] lvalue_type = TupleType(types, self.named_type("builtins.tuple")) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 73282c94be4eb..11dd52aacce6c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,18 +18,32 @@ from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker -from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars +from mypy.constraints import ( + SUBTYPE_OF, + SUPERTYPE_OF, + Constraint, + infer_constraints, + infer_constraints_for_callable, +) +from mypy.erasetype import ( + erase_type, + remove_instance_last_known_values, + replace_meta_vars, + replace_typevar, +) from mypy.errors import ErrorInfo, ErrorWatcher, report_internal_error from mypy.expandtype import ( expand_type, expand_type_by_instance, freshen_all_functions_type_vars, freshen_function_type_vars, + get_freshened_tvar_mapping, ) -from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments +from mypy.infer import ArgumentInferContext, infer_function_type_arguments +from mypy.join import join_types from mypy.literals import literal from mypy.maptype import map_instance_to_supertype -from mypy.meet import is_overlapping_types, narrow_declared_type +from mypy.meet import is_overlapping_types, meet_types, 
narrow_declared_type from mypy.message_registry import ErrorMessage from mypy.messages import MessageBuilder, format_type from mypy.nodes import ( @@ -37,7 +51,10 @@ ARG_POS, ARG_STAR, ARG_STAR2, + CONTRAVARIANT, + COVARIANT, IMPLICITLY_ABSTRACT, + INVARIANT, LAMBDA_NAME, LITERAL_TYPE, REVEAL_LOCALS, @@ -110,9 +127,11 @@ Plugin, ) from mypy.semanal_enum import ENUM_BASES +from mypy.solve import solve_constraints from mypy.state import state from mypy.subtypes import ( find_member, + infer_variance_in_expr, is_equivalent, is_same_type, is_subtype, @@ -140,7 +159,6 @@ get_type_vars, is_literal_type_like, make_simplified_union, - simple_literal_type, true_only, try_expanding_sum_type_to_union, try_getting_str_literals, @@ -190,12 +208,7 @@ is_named_instance, split_with_prefix_and_suffix, ) -from mypy.types_utils import ( - is_generic_instance, - is_overlapping_none, - is_self_type_like, - remove_optional, -) +from mypy.types_utils import is_generic_instance, is_self_type_like, remove_optional from mypy.typestate import type_state from mypy.typevars import fill_typevars from mypy.visitor import ExpressionVisitor @@ -231,6 +244,9 @@ "builtins.memoryview", } +HACKS: bool = False +PREFER_INNER_OVER_OUTER: bool = False + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -277,6 +293,9 @@ class ExpressionChecker(ExpressionVisitor[Type], ExpressionCheckerSharedApi): msg: MessageBuilder # Type context for type inference type_context: list[Type | None] + expr_context: list[Expression] + # constraints for the type context, used for type inference + constraint_context: list[list[Constraint]] # cache resolved types in some cases resolved_type: dict[Expression, ProperType] @@ -303,6 +322,8 @@ def __init__( # time for nested expressions. self.in_expression = False self.type_context = [None] + self.constraint_context = [[]] + self.expr_context = [] # Temporary overrides for expression types. 
This is currently # used by the union math in overloads. @@ -1571,7 +1592,16 @@ def check_call( if overloaded_result is not None: return overloaded_result - return self.check_callable_call( + _show( + f"\n=== CHECKING ===================" + f"\n\tcallee={callee}" + f"\n\targs={args} " + f"\n\targ_kinds={arg_kinds}" + f"\n\ttype_context={self.type_context}" + f"\n\tconstraints={self.constraint_context}" + ) + + result = self.check_callable_call( callee, args, arg_kinds, @@ -1581,6 +1611,14 @@ def check_call( callable_name, object_type, ) + _show( + f"\n=== RESULT ===================" + f"\n\n\tret_type={result[0]}" + f"\n\tinferred_callee={result[1]}" + ) + + return result + elif isinstance(callee, Overloaded): return self.check_overload_call( callee, args, arg_kinds, arg_names, callable_name, object_type, context @@ -1662,6 +1700,7 @@ def check_callable_call( See the docstring of check_call for more information. """ + # self.expr_cache.clear() # FOR DEBUGGING # Always unpack **kwargs before checking a call. callee = callee.with_unpacked_kwargs().with_normalized_var_args() if callable_name is None and callee.name: @@ -1734,10 +1773,6 @@ def check_callable_call( freeze_all_type_vars(fresh_ret_type) callee = callee.copy_modified(ret_type=fresh_ret_type) - if callee.is_generic(): - callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, @@ -1750,6 +1785,24 @@ def check_callable_call( need_refresh = any( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) + if HACKS: + tvmap = get_freshened_tvar_mapping(callee) + # apply the new tvars (to callee and context!) + callee = expand_type(callee, tvmap).copy_modified(variables=list(tvmap.values())) + # update the constraints with the new tvars. 
+ self.constraint_context[-1] = [ + Constraint( + tvmap.get(c.type_var, c.original_type_var), + c.op, + expand_type(c.target, tvmap), + ) + for c in self.constraint_context[-1] + ] + if self.type_context[-1] is not None: + self.type_context[-1] = expand_type(self.type_context[-1], tvmap) + else: + callee = freshen_function_type_vars(callee) + callee = self.infer_function_type_arguments( callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) @@ -1936,6 +1989,7 @@ def infer_arg_types_in_context( args: list[Expression], arg_kinds: list[ArgKind], formal_to_actual: list[list[int]], + constraints: list[Constraint] | None = None, ) -> list[Type]: """Infer argument expression types using a callable type as context. @@ -1956,7 +2010,7 @@ def infer_arg_types_in_context( # cases. A cleaner alternative would be to switch to single bin type # inference, but this is a lot of work. old = self.infer_more_unions_for_recursive_type(arg_type) - res[ai] = self.accept(args[ai], arg_type) + res[ai] = self.accept(args[ai], arg_type, constraints=constraints) # We need to manually restore union inference state, ugh. type_state.infer_unions = old @@ -1967,9 +2021,9 @@ def infer_arg_types_in_context( assert all(tp is not None for tp in res) return cast(list[Type], res) - def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context - ) -> CallableType: + def infer_constraints_from_context( + self, callee: CallableType, error_context: Context + ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. For example, if the return type is set[t] where 't' is a type variable @@ -1978,23 +2032,23 @@ def infer_function_type_arguments_using_context( """ ctx = self.type_context[-1] if not ctx: - return callable + return [] # The return type may have references to type metavariables that # we are inferring right now. 
We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) - ret_type = callable.ret_type - if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # If both the context and the return type are optional, unwrap the optional, - # since in 99% cases this is what a user expects. In other words, we replace - # Optional[T] <: Optional[int] - # with - # T <: int - # while the former would infer T <: Optional[int]. - ret_type = remove_optional(ret_type) - erased_ctx = remove_optional(erased_ctx) + proper_ctx = get_proper_type(ctx) + proper_ret = get_proper_type(callee.ret_type) + if isinstance(proper_ret, UnionType) and isinstance(proper_ctx, UnionType): + # If both the context and the return type are unions, we simplify shared items + # e.g. T | None <: int | None => T <: int + # since the former would infer T <: int | None. + # whereas the latter would infer the more precise T <: int. + new_ret = [val for val in proper_ret.items if val not in proper_ctx.items] + new_ctx = [val for val in proper_ctx.items if val not in proper_ret.items] + proper_ret = make_simplified_union(new_ret) + proper_ctx = make_simplified_union(new_ctx) # # TODO: Instead of this hack and the one below, we need to use outer and # inner contexts at the same time. This is however not easy because of two @@ -2005,7 +2059,9 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) 
- proper_ret = get_proper_type(ret_type) + ctx = proper_ctx + ctx = update_type_context_using_constraints(ctx, self.constraint_context[-1]) + if ( isinstance(proper_ret, TypeVarType) or isinstance(proper_ret, UnionType) @@ -2035,22 +2091,9 @@ def infer_function_type_arguments_using_context( # TODO: we may want to add similar exception if all arguments are lambdas, since # in this case external context is almost everything we have. if not is_generic_instance(ctx) and not is_literal_type_like(ctx): - return callable.copy_modified() - args = infer_type_arguments( - callable.variables, ret_type, erased_ctx, skip_unsatisfied=True - ) - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in args: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - return self.apply_generic_arguments( - callable, new_args, error_context, skip_unsatisfied=True - ) + return [] + + return infer_constraints(proper_ret, ctx, SUBTYPE_OF) def infer_function_type_arguments( self, @@ -2069,14 +2112,67 @@ def infer_function_type_arguments( Return a derived callable type that has the arguments applied. """ if self.chk.in_checked_function(): + # compute the outer solution + outer_constraints = self.infer_constraints_from_context(callee_type, context) + minimize = True + maximize = False + # detect if we are in an outermost call context. 
+ # if len(self.type_context) >=2 and self.type_context[-2] is None and self.type_context[-1] is not None: + # minimize = False + # maximize = True + # if len(self.type_context) >= 2 and self.type_context[-2] is None and self.type_context[-1] is not None: + # minimize = False + # maximize = True + + outer_solution, _ = solve_constraints( + callee_type.variables, + outer_constraints, # + self.constraint_context[-1], + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + minimize=minimize, + maximize=maximize, + ) + outer_solution = filter_solution(outer_solution) + outer_callee = self.apply_generic_arguments( + callee_type, outer_solution, context, skip_unsatisfied=True + ) + + # compute the trivial constraints + trivial_constraints = get_trivial_constraints(callee_type.variables) + + # _show( + # f"\n=== DEBUG ============================" + # f"\ninfer_function_type_arguments: " + # f"\n\t{callee_type=}" + # f"\n\t{callee_type.special_sig=}" + # f"\n\t{self.type_context=}" + # f"\n\t{self.constraint_context=}" + # f"\n\t{trivial_constraints=}" + # f"\n\t{outer_constraints=}" + # f"\n\t{outer_solution=}" + # f"\n\t{outer_callee=}" + # ) + # Disable type errors during type inference. There may be errors # due to partial available context information at this time, but # these errors can be safely ignored as the arguments will be # inferred again later. with self.msg.filter_errors(): arg_types = self.infer_arg_types_in_context( - callee_type, args, arg_kinds, formal_to_actual + callee_type, + args, + arg_kinds, + formal_to_actual, + # Adding the trivial constraints ensures the context is always non-empty + # relevant for `testWideOuterContext` unit tests. 
+ constraints=outer_constraints + trivial_constraints, ) + _show( + f"\n===== INFERRED ARGS =====" + f"\n\t{callee_type.arg_types}" + f"\n\tconstraints={outer_constraints + trivial_constraints}" + f"\n\t{arg_types=}" + ) arg_pass_nums = self.get_arg_infer_passes( callee_type, args, arg_types, formal_to_actual, len(args) @@ -2089,15 +2185,184 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args, _ = infer_function_type_arguments( - callee_type, - pass1_args, - arg_kinds, - arg_names, - formal_to_actual, - context=self.argument_infer_context(), - strict=self.chk.in_checked_function(), - ) + if True: # NEW CODE + # compute the inner solution + _inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + inner_upper, inner_lower = get_upper_and_lower(_inner_constraints) + # inner_upper = [ + # Constraint(c.original_type_var, c.op, forget_last_known_value(c.target)) + # # Constraint(c.original_type_var, c.op, use_last_known_value(c.target)) + # for c in inner_upper + # ] + # inner_lower = [ + # # Constraint(c.original_type_var, c.op, forget_last_known_value(c.target)) + # Constraint(c.original_type_var, c.op, use_last_known_value(c.target)) + # for c in inner_lower + # ] + inner_constraints = inner_upper + inner_lower + + # HACK: convert "Literal?" constraints to their non-literal versions. + # relevant for `testLiteral*` tests. 
+ inner_constraints = [ + # Constraint(c.original_type_var, c.op, forget_last_known_value(c.target)) + Constraint(c.original_type_var, c.op, c.target) + for c in _inner_constraints + ] + inner_solution, _ = solve_constraints( + callee_type.variables, + inner_constraints + trivial_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + minimize=True, # <- essentially skips variables without lower bounds + ) + inner_solution = filter_solution(inner_solution) + inner_callee = self.apply_generic_arguments( + callee_type, inner_solution, context, skip_unsatisfied=True + ) + + # compute the joint solution using both inner and outer constraints. + # NOTE: The order of constraints is important here! + # solve(outer + inner) and solve(inner + outer) may yield different results. + # see https://github.com/python/mypy/issues/19551 + joint_constraints = outer_constraints + trivial_constraints + inner_constraints + joint_solution, _ = solve_constraints( + callee_type.variables, + joint_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + minimize=True, # <- essentially skips variables without lower bounds + ) + # filter out non-solutions containing "erased" or "uninhabited" + joint_solution = filter_solution(joint_solution) + joint_callee = self.apply_generic_arguments( + callee_type, joint_solution, context, skip_unsatisfied=True + ) + + # determine if solution was successful: + if is_proper_solution(callee_type, joint_solution): + extra_constraints = outer_constraints + inner_constraints + trivial_constraints + elif is_proper_solution(callee_type, outer_solution): + extra_constraints = outer_constraints + trivial_constraints + else: + extra_constraints = outer_constraints + inner_constraints + + # determine which solution to take + use_joint = all( + # only use joint if it solved at least the same variables as the outer solution + # That is if: ```joint[k]=None ⟹ outer[k]=None``` + # or, equivalently, if ```not 
(joint[k]=None) or outer[k]=None``` + (j is not None or o is None) + for j, o in zip(joint_solution, outer_solution) + ) + # currently disabled + use_inner = PREFER_INNER_OVER_OUTER and not use_joint + + if use_joint: + inferred_args = joint_solution + elif use_inner: + inferred_args = inner_solution + else: + inferred_args = outer_solution + + # if ( + # callee_type.special_sig == "dict" + # and len(inferred_args) == 2 + # and (ARG_NAMED in arg_kinds or ARG_STAR2 in arg_kinds) + # ): + # # HACK: Infer str key type for dict(...) with keyword args. The type system + # # can't represent this so we special case it, as this is a pretty common + # # thing. This doesn't quite work with all possible subclasses of dict + # # if they shuffle type variables around, as we assume that there is a 1-1 + # # correspondence with dict type variables. This is a marginal issue and + # # a little tricky to fix so it's left unfixed for now. + # first_arg = get_proper_type(inferred_args[0]) + # if first_arg is None or isinstance(first_arg, UninhabitedType): + # inferred_args[0] = self.named_type("builtins.str") + # elif not first_arg or not is_subtype( + # self.named_type("builtins.str"), first_arg + # ): + # self.chk.fail( + # message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context + # ) + + # _show( + # f"\n=== DEBUG ============================" + # f"\n\t{callee_type=}" + # f"\n\t{callee_type.special_sig=}" + # f"\n\t{self.type_context=}" + # f"\n\t{self.constraint_context=}" + # f"\n\t{arg_types=}" + # f"\n\t{pass1_args=}" + # f"\n\t{outer_constraints=}" + # f"\n\t{inner_constraints=}" + # f"\n\t{trivial_constraints=}" + # f"\n\t{joint_constraints=}" + # f"\n\t{outer_solution=}" + # f"\n\t{inner_solution=}" + # f"\n\t{joint_solution=}" + # f"\n\t{outer_callee=}" + # f"\n\t{inner_callee=}" + # f"\n\t{joint_callee=}" + # f"\n\t{use_joint=}" + # f"\n\t{inferred_args=}" + # f"\n" + # f"\n\tresult={self.apply_generic_arguments(callee_type, inferred_args, context, 
skip_unsatisfied=True)}" + # ) + + if use_joint: + inferred_args = joint_solution + elif use_inner: + # If we can use the inner solution, apply it. + callee_type = self.apply_generic_arguments( + callee_type, inner_solution, context, skip_unsatisfied=True + ) + # recompute the outer solution + new_outer_constraints = self.infer_constraints_from_context( + callee_type, context + ) + new_outer_solution, _ = solve_constraints( + callee_type.variables, + new_outer_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + inferred_args = filter_solution(new_outer_solution) + # _show(f"Two stage inference: \n\t{callee_type=}\n\t{inferred_args=}") + else: + # If we cannot use the joint solution, fall back to a 2 stage inference, + # by first applying the outer solution, and then inferring the inner again + callee_type = self.apply_generic_arguments( + callee_type, outer_solution, context, skip_unsatisfied=True + ) + + # QUESTION: Do we need to recompute formal_to_actual, arg_types and pass1_args here??? + # recompute and apply inner solution. + new_inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + new_inner_solution, _ = solve_constraints( + callee_type.variables, + new_inner_constraints + trivial_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + minimize=True, + ) + inferred_args = filter_solution(new_inner_solution) + # _show(f"Two stage inference: \n\t{callee_type=}\n\t{inferred_args=}") + else: # END NEW CODE + pass if 2 in arg_pass_nums: # Second pass of type inference. 
@@ -2111,6 +2376,15 @@ def infer_function_type_arguments( need_refresh, context, ) + # _show( + # f"\n=== DEBUG ============================" + # f"\nPASS2: " + # f"\n\t{callee_type=}" + # f"\n\t{callee_type.special_sig=}" + # f"\n\t{self.type_context=}" + # f"\n\t{self.constraint_context=}" + # f"\n\t{inferred_args=}" + # ) if ( callee_type.special_sig == "dict" @@ -2124,16 +2398,13 @@ def infer_function_type_arguments( # correspondence with dict type variables. This is a marginal issue and # a little tricky to fix so it's left unfixed for now. first_arg = get_proper_type(inferred_args[0]) - if isinstance(first_arg, (NoneType, UninhabitedType)): + if first_arg is None or isinstance(first_arg, UninhabitedType): inferred_args[0] = self.named_type("builtins.str") elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) - if not self.chk.options.old_type_inference and any( - a is None - or isinstance(get_proper_type(a), UninhabitedType) - or set(get_type_vars(a)) & set(callee_type.variables) - for a in inferred_args + if not self.chk.options.old_type_inference and not is_proper_solution( + callee_type, inferred_args ): if need_refresh: # Technically we need to refresh formal_to_actual after *each* inference pass, @@ -2159,17 +2430,18 @@ def infer_function_type_arguments( context=self.argument_infer_context(), strict=self.chk.in_checked_function(), allow_polymorphic=True, + extra_constraints=extra_constraints, + minimize=True, ) poly_callee_type = self.apply_generic_arguments( callee_type, poly_inferred_args, context ) + # _show(f"\n\t{poly_inferred_args=}\n\t{free_vars=}\n\t{poly_callee_type=}") # Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can # be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed. 
applied = applytype.apply_poly(poly_callee_type, free_vars) - if applied is not None and all( - a is not None and not isinstance(get_proper_type(a), UninhabitedType) - for a in poly_inferred_args - ): + if applied is not None and all(is_solution(a) for a in poly_inferred_args): + # _show(f"\nTriggered polymorphic inference:\n\t{applied=}") freeze_all_type_vars(applied) return applied # If it didn't work, erase free variables as uninhabited, to avoid confusing errors. @@ -2185,11 +2457,17 @@ def infer_function_type_arguments( ) for a in poly_inferred_args ] + result = self.apply_inferred_arguments(callee_type, inferred_args, context) + + # _show(f"=== OUTPUT ============================result={result}") + + return result else: # In dynamically typed functions use implicit 'Any' types for # type variables. inferred_args = [AnyType(TypeOfAny.unannotated)] * len(callee_type.variables) - return self.apply_inferred_arguments(callee_type, inferred_args, context) + result = self.apply_inferred_arguments(callee_type, inferred_args, context) + return result def infer_function_type_arguments_pass2( self, @@ -5100,6 +5378,14 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: if t: return t + # take context into account. + # self.type_context[-1] should be a list-type + # for example list[T] or list[T | int], list[T | S], etc. + # self.constraint_context[-1] may contain constraints on those variables. + # What we should do is consider a constructor: + # def [T] (T) -> list[T] + # then match the T against the arguments of the type_context[-1]. + # Translate into type checking a generic function call. 
# Used for list and set expressions, as well as for tuples # containing star expressions that don't refer to a @@ -5112,6 +5398,22 @@ def check_lst_expr(self, e: ListExpr | SetExpr | TupleExpr, fullname: str, tag: upper_bound=self.object_type(), default=AnyType(TypeOfAny.from_omitted_generics), ) + + if HACKS: + if self.type_context[-1] is not None: + new_type = self.chk.named_generic_type(fullname, [tv]) + ctx_vars = get_type_vars(self.type_context[-1]) + ctx_cons = self.constraint_context[-1] + ctx_sol, _ = solve_constraints(ctx_vars, ctx_cons) + # apply the solution + ctx_tvmap = { + v.id: sol for i, v in enumerate(ctx_vars) if (sol := ctx_sol[i]) is not None + } + self.type_context[-1] = expand_type(self.type_context[-1], ctx_tvmap) + outer_constraints = infer_constraints(new_type, self.type_context[-1], SUBTYPE_OF) + outer_solution, _ = solve_constraints([tv], outer_constraints) + self.constraint_context[-1] = outer_constraints + constructor = CallableType( [tv], [nodes.ARG_STAR], @@ -5886,63 +6188,26 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F elif else_map is None: self.msg.redundant_condition_in_if(True, e.cond) + if ctx is None: + # When no context is provided, compute each branch individually, and + # use the union of the results as artificial context. Important for: + # - testUnificationDict + # - testConditionalExpressionWithEmpty + ctx_if_type = self.analyze_cond_branch( + if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return + ) + ctx_else_type = self.analyze_cond_branch( + else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return + ) + ctx = make_simplified_union([ctx_if_type, ctx_else_type]) + if_type = self.analyze_cond_branch( if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return ) - - # we want to keep the narrowest value of if_type for union'ing the branches - # however, it would be silly to pass a literal as a type context. 
Pass the - # underlying fallback type instead. - if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type - - # Analyze the right branch using full type context and store the type - full_context_else_type = self.analyze_cond_branch( + else_type = self.analyze_cond_branch( else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return ) - if not mypy.checker.is_valid_inferred_type(if_type, self.chk.options): - # Analyze the right branch disregarding the left branch. - else_type = full_context_else_type - # we want to keep the narrowest value of else_type for union'ing the branches - # however, it would be silly to pass a literal as a type context. Pass the - # underlying fallback type instead. - else_type_fallback = simple_literal_type(get_proper_type(else_type)) or else_type - - # If it would make a difference, re-analyze the left - # branch using the right branch's type as context. - if ctx is None or not is_equivalent(else_type_fallback, ctx): - # TODO: If it's possible that the previous analysis of - # the left branch produced errors that are avoided - # using this context, suppress those errors. - if_type = self.analyze_cond_branch( - if_map, - e.if_expr, - context=else_type_fallback, - allow_none_return=allow_none_return, - ) - - elif if_type_fallback == ctx: - # There is no point re-running the analysis if if_type is equal to ctx. - # That would be an exact duplicate of the work we just did. - # This optimization is particularly important to avoid exponential blowup with nested - # if/else expressions: https://github.com/python/mypy/issues/9591 - # TODO: would checking for is_proper_subtype also work and cover more cases? - else_type = full_context_else_type - else: - # Analyze the right branch in the context of the left - # branch's type. 
- else_type = self.analyze_cond_branch( - else_map, - e.else_expr, - context=if_type_fallback, - allow_none_return=allow_none_return, - ) - - # In most cases using if_type as a context for right branch gives better inferred types. - # This is however not the case for literal types, so use the full context instead. - if is_literal_type_like(full_context_else_type) and not is_literal_type_like(else_type): - else_type = full_context_else_type - res: Type = make_simplified_union([if_type, else_type]) if has_uninhabited_component(res) and not isinstance( get_proper_type(self.type_context[-1]), UnionType @@ -5998,11 +6263,13 @@ def accept( allow_none_return: bool = False, always_allow_any: bool = False, is_callee: bool = False, + constraints: list[Constraint] | None = None, ) -> Type: """Type check a node in the given type context. If allow_none_return is True and this expression is a call, allow it to return None. This applies only to this expression and not any subexpressions. """ + # self.expr_cache.clear() # debugging if node in self.type_overrides: # This branch is very fast, there is no point timing it. return self.type_overrides[node] @@ -6012,7 +6279,37 @@ def accept( t0 = time.perf_counter_ns() self.in_expression = True record_time = True + self.constraint_context.append(constraints or []) + # type_context = self.update_type_context_using_constraints(type_context) self.type_context.append(type_context) + self.expr_context.append(node) + + # fix the type context using the contextual constraints + if HACKS and type_context is not None: + # compute solved outer context. 
+ ctx_vars = get_type_vars(type_context) + ctx_cons = self.constraint_context[-1] + + # add trivial constraints: + # naive_constraints = [ + # Constraint(t, SUBTYPE_OF, t.upper_bound) + # for t in ctx_vars + # if isinstance(t, TypeVarType) + # ] + # ctx_cons += naive_constraints + + ctx_sol, _ = solve_constraints(ctx_vars, ctx_cons, minimize=True) + + # Skip types that did not resolve + ctx_sol = filter_solution(ctx_sol) + + # apply the solution + ctx_tvmap = { + v.id: sol for i, v in enumerate(ctx_vars) if (sol := ctx_sol[i]) is not None + } + self.type_context[-1] = expand_type(type_context, ctx_tvmap) + type_context = self.type_context[-1] + old_is_callee = self.is_callee self.is_callee = is_callee try: @@ -6052,6 +6349,8 @@ def accept( ) self.is_callee = old_is_callee self.type_context.pop() + self.expr_context.pop() + self.constraint_context.pop() assert typ is not None self.chk.store_type(node, typ) @@ -6765,3 +7064,168 @@ def is_type_type_context(context: Type | None) -> bool: if isinstance(context, UnionType): return any(is_type_type_context(item) for item in context.items) return False + + +def is_solution(t: Type | None) -> bool: + """Whether t is a proper solution.""" + return not ( + t is None + or isinstance(get_proper_type(t), UninhabitedType) # uninhabited types are not solutions + or has_erased_component(t) # list[] is not a solution + # note: list[Never] is inhabited and can be a solution + ) + + +def filter_solution(args: list[Type | None]) -> list[Type | None]: + r"""Filter out non-solutions containing "erased" or "uninhabited".""" + return [arg if is_solution(arg) else None for arg in args] + + +def forget_last_known_value(t: Type, /) -> Type: + """Forget the last known value of a type.""" + p_t = get_proper_type(t) + return p_t.copy_modified(last_known_value=None) if isinstance(p_t, Instance) else p_t + + +def use_last_known_value(t: Type, /) -> Type: + """Use the last known value of a type, if it has one.""" + p_t = get_proper_type(t) + if 
isinstance(p_t, Instance) and p_t.last_known_value is not None: + return p_t.last_known_value + return t # No last known value, return the original type unchanged + + +# – + + +def get_upper_and_lower( + constraints: list[Constraint], +) -> tuple[list[Constraint], list[Constraint]]: + """Get upper and lower bounds from a list of constraints.""" + upper = [] + lower = [] + for constraint in constraints: + if constraint.op == SUBTYPE_OF: + upper.append(constraint) + elif constraint.op == SUPERTYPE_OF: + lower.append(constraint) + else: + raise ValueError(f"Unexpected constraint operator: {constraint.op}") + return upper, lower + + +def _show(*args: object) -> None: + if False: + print(*args) + + +def get_trivial_constraints( + tvars: Sequence[TypeVarLikeType], /, *, filter_object: bool = True +) -> Sequence[Constraint]: + """Every Type Variable is a subtype of its upper bound.""" + # each TVar is a subtype of its upper bound + constraints: list[Constraint] = [] + for t in tvars: + upper_bound = get_proper_type(t.upper_bound) + + # An upper bound of 'object' is not informative, and usually just means no + # constraint was specified. We skip these because the solver picks the upper + # bound as a solution if there are no lower constraints. + if ( + filter_object + and isinstance(upper_bound, Instance) + and upper_bound.type.fullname == "builtins.object" + ): + pass + else: + constraints.append(Constraint(t, SUBTYPE_OF, upper_bound)) + + if isinstance(t, TypeVarType) and t.values: + constraints.append(Constraint(t, SUBTYPE_OF, make_simplified_union(t.values))) + return constraints + + +def get_upper_bounds( + tvars: Sequence[TypeVarLikeType], constraints: Sequence[Constraint] +) -> dict[TypeVarId, Type]: + # for each tvar, find all upper constraints on it. + # then, update the upper bound of the tvar to be the intersection of + # the upper bounds. 
+ upper_bounds: dict[TypeVarId, Type] = {} + for tvar in tvars: + relevant_constraints = [ + c for c in constraints if c.type_var == tvar.id and c.op == SUBTYPE_OF + ] + top = tvar.upper_bound + for c in relevant_constraints: + top = meet_types(top, c.target) + + upper_bounds[tvar.id] = top + return upper_bounds + + +def get_lower_bounds( + tvars: Sequence[TypeVarLikeType], constraints: Sequence[Constraint] +) -> dict[TypeVarId, Type]: + # for each tvar, find all lower constraints on it. + # then, update the lower bound of the tvar to be the union of + # the lower bounds. + lower_bounds: dict[TypeVarId, Type] = {} + for tvar in tvars: + relevant_constraints = [ + c for c in constraints if c.type_var == tvar.id and c.op == SUPERTYPE_OF + ] + bottom: Type = UninhabitedType() + + for c in relevant_constraints: + bottom = join_types(bottom, c.target) + lower_bounds[tvar.id] = bottom + return lower_bounds + + +def update_type_context_using_constraints(typ: Type, /, constraints: list[Constraint]) -> Type: + tvars = get_all_type_vars(typ) + + # concatenate self.constraint_context, which is a list of lists + upper_bounds = get_upper_bounds(tvars, constraints) + lower_bounds = get_lower_bounds(tvars, constraints) + + for tvar in tvars: + if not tvar.id.is_meta_var(): + continue + + if isinstance(tvar, TypeVarType): + variance = infer_variance_in_expr(typ, tvar) + + if variance == INVARIANT: + typ = replace_typevar(typ, tvar.id, ErasedType()) + + elif variance == COVARIANT: + contextual_upper_bound = upper_bounds[tvar.id] + typ = replace_typevar(typ, tvar.id, contextual_upper_bound) + + elif variance == CONTRAVARIANT: + contextual_lower_bound = lower_bounds[tvar.id] + typ = replace_typevar(typ, tvar.id, contextual_lower_bound) + + else: + raise ValueError(f"Unexpected variance: {variance}") + + elif isinstance(tvar, TypeVarTupleType): + typ = replace_typevar(typ, tvar.id, ErasedType()) + elif isinstance(tvar, ParamSpecType): + typ = replace_typevar(typ, tvar.id, 
ErasedType()) + else: + # Add other branches if more TypeVarLikeType are added + raise TypeError(f"Unexpected type variable type: {tvar}") + + return typ + + +def is_proper_solution(callee_type: CallableType, inferred_args: Sequence[Type | None]) -> bool: + return not any( + a is None + or isinstance(get_proper_type(a), UninhabitedType) + or (set(get_type_vars(a)) & set(callee_type.variables)) + for a in inferred_args + ) diff --git a/mypy/constraints.py b/mypy/constraints.py index 96c0c7ccaf35e..1833ee31ecb71 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -77,11 +77,13 @@ class Constraint: """ type_var: TypeVarId + original_type_var: TypeVarLikeType op = 0 # SUBTYPE_OF or SUPERTYPE_OF target: Type def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.type_var = type_var.id + self.original_type_var = type_var self.op = op # TODO: should we add "assert not isinstance(target, UnpackType)"? # UnpackType is a synthetic type, and is never valid as a constraint target. @@ -1356,7 +1358,11 @@ def visit_typeddict_type(self, template: TypedDictType) -> list[Constraint]: # NOTE: Non-matching keys are ignored. Compatibility is checked # elsewhere so this shouldn't be unsafe. for item_name, template_item_type, actual_item_type in template.zip(actual): - res.extend(infer_constraints(template_item_type, actual_item_type, self.direction)) + # Value type is invariant, so irrespective of the direction, + # we constrain both above and below. 
+ # Fixes testTypedDictWideContext + res.extend(infer_constraints(template_item_type, actual_item_type, SUBTYPE_OF)) + res.extend(infer_constraints(template_item_type, actual_item_type, SUPERTYPE_OF)) return res elif isinstance(actual, AnyType): return self.infer_against_any(template.items.values(), actual) diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 6645bcf916d90..163fd75e51f06 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -158,7 +158,18 @@ def erase_meta_id(id: TypeVarId) -> bool: return id.is_meta_var() -def replace_meta_vars(t: Type, target_type: Type) -> Type: +def replace_typevar(t: Type, tvar_id: TypeVarId, replacement: Type) -> Type: + """Replace type variable in a type with the target type.""" + + def replace_id(id: TypeVarId) -> bool: + return id == tvar_id + + return t.accept(TypeVarEraser(replace_id, replacement)) + + +def replace_meta_vars( + t: Type, target_type: Type, ids_to_replace: Container[TypeVarId] | None = None +) -> Type: """Replace unification variables in a type with the target type.""" return t.accept(TypeVarEraser(erase_meta_id, target_type)) @@ -233,6 +244,26 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: return t.copy_modified(args=[a.accept(self) for a in t.args]) +class TypeVarSubstitutor(TypeTranslator): + """Substitute a type variable with a given replacement. + + Args: + erase_id: A callable that returns True if the type variable should be replaced. + If None, all type variables are replaced. + replacement: The type to replace the type variable with. + covariant_replacement (optional): The type to replace the type variable when it is used in a covariant position. + contravariant_replacement (optional): The type to replace the type variable when it is used in a contravariant position. 
+ + Examples: + We have an upper-bounded type variable `T <: MyType` + We know we have a type-var constraint `S <: T` + We can create a weaker constraint `S <: MyType` by using the upper bound + Likewise, when we have `S <: Callable[[T], R]`, we can create a weaker constraint `S <: Callable[[Never], R]`, + by using a lower bound of `T` which is `Never`. + So generally, we can use lower bounds for contravariant positions and upper bounds for covariant positions. + """ + + def remove_instance_last_known_values(t: Type) -> Type: return t.accept(LastKnownValueEraser()) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index e2a42317141f4..e551e7b8f7970 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -137,6 +137,16 @@ def freshen_function_type_vars(callee: F) -> F: return cast(F, fresh_overload) +def get_freshened_tvar_mapping(callee: CallableType) -> dict[TypeVarId, TypeVarLikeType]: + """Substitute fresh type variables for generic function type variables.""" + assert isinstance(callee, CallableType) + tvmap: dict[TypeVarId, TypeVarLikeType] = {} + for v in callee.variables: + tv = v.new_unification_variable(v) + tvmap[v.id] = tv + return tvmap + + class HasGenericCallable(BoolTypeQuery): def __init__(self) -> None: super().__init__(ANY_STRATEGY) diff --git a/mypy/infer.py b/mypy/infer.py index cdc43797d3b16..603bef5e0b51b 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -8,6 +8,7 @@ from mypy.constraints import ( SUBTYPE_OF, SUPERTYPE_OF, + Constraint, infer_constraints, infer_constraints_for_callable, ) @@ -39,6 +40,8 @@ def infer_function_type_arguments( context: ArgumentInferContext, strict: bool = True, allow_polymorphic: bool = False, + extra_constraints: list[Constraint] | None = None, + minimize: bool = False, ) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Infer the type arguments of a generic function. 
@@ -58,9 +61,12 @@ def infer_function_type_arguments( callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context ) + if extra_constraints: + constraints += extra_constraints + # Solve constraints. type_vars = callee_type.variables - return solve_constraints(type_vars, constraints, strict, allow_polymorphic) + return solve_constraints(type_vars, constraints, strict, allow_polymorphic, minimize=minimize) def infer_type_arguments( diff --git a/mypy/join.py b/mypy/join.py index 099df02680f06..425e97bf36901 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -743,6 +743,7 @@ def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableT def join_similar_callables(t: CallableType, s: CallableType) -> CallableType: + # join(X₁ -> Y₁, X₂ -> Y₂) = meet(X₁, X₂) -> join(Y₁, Y₂) t, s = match_generic_callables(t, s) arg_types: list[Type] = [] for i in range(len(t.arg_types)): diff --git a/mypy/meet.py b/mypy/meet.py index 353af59367ad1..92cd37738e91f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -1139,6 +1139,7 @@ def default(self, typ: Type) -> ProperType: def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType: + # meet(X₁ -> Y₁, X₂ -> Y₂) = join(X₁, X₂) -> meet(Y₁, Y₂) from mypy.join import match_generic_callables, safe_join t, s = match_generic_callables(t, s) diff --git a/mypy/solve.py b/mypy/solve.py index fbbcac2520ad0..091d9bbf488e8 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -9,7 +9,7 @@ from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op from mypy.expandtype import expand_type from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort -from mypy.join import join_type_list +from mypy.join import join_type_list, object_or_any_from_type from mypy.meet import meet_type_list, meet_types from mypy.subtypes import is_subtype from mypy.typeops import get_all_type_vars @@ -44,6 +44,8 @@ def solve_constraints( strict: bool = True, allow_polymorphic: bool = 
False, skip_unsatisfied: bool = False, + minimize: bool = False, + maximize: bool = False, ) -> tuple[list[Type | None], list[TypeVarLikeType]]: """Solve type constraints. @@ -82,7 +84,7 @@ def solve_constraints( if allow_polymorphic: if constraints: solutions, free_vars = solve_with_dependent( - vars + extra_vars, constraints, vars, originals + vars + extra_vars, constraints, vars, originals, minimize=minimize ) else: solutions = {} @@ -95,7 +97,7 @@ def solve_constraints( continue lowers = [c.target for c in cs if c.op == SUPERTYPE_OF] uppers = [c.target for c in cs if c.op == SUBTYPE_OF] - solution = solve_one(lowers, uppers) + solution = solve_one(lowers, uppers, minimize=minimize, maximize=maximize) # Do not leak type variables in non-polymorphic solutions. if solution is None or not get_vars( @@ -131,6 +133,7 @@ def solve_with_dependent( constraints: list[Constraint], original_vars: list[TypeVarId], originals: dict[TypeVarId, TypeVarLikeType], + minimize: bool = False, ) -> tuple[Solutions, list[TypeVarLikeType]]: """Solve set of constraints that may depend on each other, like T <: List[S]. @@ -182,13 +185,13 @@ def solve_with_dependent( solutions: dict[TypeVarId, Type | None] = {} for flat_batch in batches: - res = solve_iteratively(flat_batch, graph, lowers, uppers) + res = solve_iteratively(flat_batch, graph, lowers, uppers, minimize=minimize) solutions.update(res) return solutions, [free_solutions[tv] for tv in free_vars] def solve_iteratively( - batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds + batch: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds, minimize: bool = False ) -> Solutions: """Solve transitive closure sequentially, updating upper/lower bounds after each step. @@ -214,7 +217,7 @@ def solve_iteratively( break # Solve each solvable type variable separately. 
s_batch.remove(solvable_tv) - result = solve_one(lowers[solvable_tv], uppers[solvable_tv]) + result = solve_one(lowers[solvable_tv], uppers[solvable_tv], minimize=minimize) solutions[solvable_tv] = result if result is None: # TODO: support backtracking lower/upper bound choices and order within SCCs. @@ -256,7 +259,9 @@ def _join_sorted_key(t: Type) -> int: return 0 -def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: +def solve_one( + lowers: Iterable[Type], uppers: Iterable[Type], minimize: bool = False, maximize: bool = False +) -> Type | None: """Solve constraints by finding by using meets of upper bounds, and joins of lower bounds.""" candidate: Type | None = None @@ -310,18 +315,44 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: source_any = top if isinstance(p_top, AnyType) else bottom assert isinstance(source_any, ProperType) and isinstance(source_any, AnyType) return AnyType(TypeOfAny.from_another_any, source_any=source_any) - elif bottom is None: - if top: + + assert not (minimize and maximize) + + if minimize: # pick minimum solution + if bottom is None and top is None: + return None + elif bottom is None: + candidate = UninhabitedType() + elif top is None: + candidate = bottom + elif is_subtype(bottom, top): + candidate = bottom + else: + candidate = None + elif maximize: # choose "largest" solution + if bottom is None and top is None: + return None + elif bottom is None: + candidate = top + elif top is None: + assert p_bottom is not None + candidate = object_or_any_from_type(p_bottom) + elif is_subtype(bottom, top): candidate = top else: - # No constraints for type variable + candidate = None + else: # choose "best" solution + if bottom is None and top is None: return None - elif top is None: - candidate = bottom - elif is_subtype(bottom, top): - candidate = bottom - else: - candidate = None + elif bottom is None: + candidate = top + elif top is None: + candidate = bottom + elif 
is_subtype(bottom, top): + candidate = bottom + else: + candidate = None + return candidate diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f33..7dc7673a25ca9 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -36,6 +36,7 @@ ) from mypy.options import Options from mypy.state import state +from mypy.typeops import get_all_type_vars from mypy.types import ( MYPYC_NATIVE_INT_NAMES, TUPLE_LIKE_INSTANCE_NAMES, @@ -61,6 +62,8 @@ TypedDictType, TypeOfAny, TypeType, + TypeVarId, + TypeVarLikeType, TypeVarTupleType, TypeVarType, TypeVisitor, @@ -2175,6 +2178,40 @@ def all_non_object_members(info: TypeInfo) -> set[str]: return members +def infer_variance_in_expr(type_form: Type, tvar: TypeVarLikeType) -> int: + r"""Infer the variance of the ith type variable in a type expression. + + Assume we have a type expression `TypeForm[T1, ..., Tn]` with type variables T1, ..., Tn. + + Then this method returns: + + 1. COVARIANT, if X <: T1 implies TypeForm[X, T2, ..., Tn] <: TypeForm[T1, T2, ..., Tn] + 2. CONTRAVARIANT, if X <: T1 implies TypeForm[X, T2, ..., Tn] :> TypeForm[T1, T2, ..., Tn] + 3. INVARIANT, if neither of the above holds + """ + # 0. If the type variable does not appear in the type expression, return INVARIANT. + if tvar not in get_all_type_vars(type_form): + return INVARIANT + + fresh_var = TypeVarType( + "X", + "X", + id=TypeVarId(-2), + values=[], + # Use other TypeVar as the upper bound + # This is not officially supported, but does seem to work? + upper_bound=tvar, + default=AnyType(TypeOfAny.from_omitted_generics), + ) + new_form = expand_type(type_form, {tvar.id: fresh_var}) + + if is_subtype(new_form, type_form): + return COVARIANT + if is_subtype(type_form, new_form): + return CONTRAVARIANT + return INVARIANT + + def infer_variance(info: TypeInfo, i: int) -> bool: """Infer the variance of the ith type variable of a generic class. 
diff --git a/mypy/types.py b/mypy/types.py index e0e897e04cadf..09ff6c1b776c0 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -3029,7 +3029,11 @@ def copy_modified( if item_names is not None: items = {k: v for (k, v) in items.items() if k in item_names} required_keys &= set(item_names) - return TypedDictType(items, required_keys, readonly_keys, fallback, self.line, self.column) + result = TypedDictType( + items, required_keys, readonly_keys, fallback, self.line, self.column + ) + result.to_be_mutated = self.to_be_mutated + return result def create_anonymous_fallback(self) -> Instance: anonymous = self.as_anonymous() diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 0880c62bc7a58..576b7a7ebffd8 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -719,7 +719,9 @@ def get_literal_str(expr: Expression) -> str | None: if isinstance(expr, StrExpr): return expr.value elif isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.is_final: - return str(expr.node.final_value) + final_value = expr.node.final_value + if final_value is not None: + return str(final_value) return None for i in range(len(exprs) - 1): diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 6960b0a043038..6a62db6ee3ee0 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -412,9 +412,16 @@ def test_basics() -> None: [case testFStrings] import decimal from datetime import datetime +from typing import Final var = 'mypyc' num = 20 +final_known_at_compile_time: Final = 'hello' + +def final_value_setter() -> str: + return 'goodbye' + +final_unknown_at_compile_time: Final = final_value_setter() def test_fstring_basics() -> None: assert f'Hello {var}, this is a test' == "Hello mypyc, this is a test" @@ -451,6 +458,8 @@ def test_fstring_basics() -> None: inf_num = float('inf') assert f'{nan_num}, {inf_num}' == 'nan, inf' + assert f'{final_known_at_compile_time} 
{final_unknown_at_compile_time}' == 'hello goodbye' + # F-strings would be translated into ''.join[string literals, format method call, ...] in mypy AST. # Currently we are using a str.join specializer for f-string speed up. We might not cover all cases # and the rest ones should fall back to a normal str.join method call. diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 979da62aca925..cac06c392fcd5 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -1080,3 +1080,24 @@ class Launcher(P): [builtins fixtures/async_await.pyi] [typing fixtures/typing-async.pyi] + +[case testAsyncAssignmentOuterContext] +# https://github.com/python/mypy/issues/15569 +from typing import Awaitable, Union, TypeVar + +# fixture does not provide wait_for, so we replicate it here +T = TypeVar('T') +async def wait_for(fut: Awaitable[T], timeout: Union[float, None]) -> T: ... + +class A: + i: Union[int, None] = None + +async def bar() -> tuple[int, int]: ... 
+ +async def foo() -> int: + a = A() + a.i, _0 = await wait_for(bar(), timeout=5) + return a.i + +[builtins fixtures/async_await.pyi] +[typing fixtures/typing-async.pyi] diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ea6eac9a39b3a..88bb0fab563ac 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -460,10 +460,6 @@ class A: def __contains__(self, x: 'A') -> str: pass [builtins fixtures/bool.pyi] -[case testInWithInvalidArgs] -a = 1 in ([1] + ['x']) # E: List item 0 has incompatible type "str"; expected "int" -[builtins fixtures/list.pyi] - [case testEq] a: A b: bool @@ -2492,3 +2488,443 @@ x + T # E: Unsupported left operand type for + ("int") T() # E: "TypeVar" not callable [builtins fixtures/tuple.pyi] [typing fixtures/typing-full.pyi] + + +[case testBinaryOperatorContext] +# https://github.com/python/mypy/issues/19304 +from typing import TypeVar, Generic, Iterable, Iterator, Union + +T = TypeVar("T") +S = TypeVar("S") +IntOrStr = TypeVar("IntOrStr", bound=Union[int, str]) + +class Vec(Generic[T]): + def __init__(self, iterable: Iterable[T], /) -> None: ... + def __iter__(self) -> Iterator[T]: yield from [] + def __add__(self, value: "Vec[S]", /) -> "Vec[Union[S, T]]": return Vec([]) + +def fmt(arg: Iterable[Union[int, str]]) -> None: ... +def first(arg: Iterable[IntOrStr]) -> IntOrStr: ... + +def test_fmt(l1: Vec[int], l2: Vec[int], /) -> None: + fmt(l1 + l2) + +def test_first(l1: Vec[int], l2: Vec[int], /) -> None: + first(l1 + l2) +[builtins fixtures/list.pyi] + +[case testBinaryOperatorContext2] +# https://github.com/python/mypy/issues/15087 +class A: ... +class B(A): ... + +def foo(x: list[A]) -> None: ... 
+ +def test(l1: list[A], l2: list[B]) -> None: + foo(l1 + l2) +[builtins fixtures/list.pyi] + +[case testBinaryOperatorContext3] +# https://github.com/python/mypy/issues/19413 +from typing import Iterable, TypeVar + +T = TypeVar("T") + +def id_(iter1: Iterable[T]) -> Iterable[T]: ... + +def with_none(x: list[str]) -> None: + reveal_type(id_([None] + x)) # N: Revealed type is "typing.Iterable[Union[None, builtins.str]]" + +[builtins fixtures/list.pyi] + +[case testComparisonContext4] +# https://github.com/python/mypy/issues/12247 +def IgnoreColor(s: bytes) -> bytes: + return bytes(filter(lambda x: not (0x01 <= x <= 0x1F or x == 0x7F), s)) +[builtins fixtures/filter.pyi] + +[case testBinaryOperatorContextListConcat] +# https://github.com/python/mypy/issues/5874 +from typing import TypeVar, Union + +T = TypeVar('T') +S = TypeVar('S') + +def add(x: list[T], y: list[S]) -> list[Union[T, S]]: ... + +a: list[int] +b: list[str] + +e: list[Union[int, str]] = add(a, b) +[builtins fixtures/list.pyi] + +[case testBinaryOperatorContextListConcat2] +# https://github.com/python/mypy/issues/3933 +from typing import Union + +def test_left(mix: list[Union[int, str]], strings: list[str]) -> None: + mix = mix + strings + +def test_right(mix: list[Union[int, str]], strings: list[str]) -> None: + mix = strings + mix + +def test_inplace(mix: list[Union[int, str]], strings: list[str]) -> None: + mix += strings +[builtins fixtures/list.pyi] + +[case testBinaryOperatorContextDictUnion] +# flags: --python-version 3.10 +# https://github.com/python/mypy/issues/18236 +d1: dict[str, str] = {} +d2: dict[str, str | None] = {} + +d3 = d1 | d2 +reveal_type(d3) # N: Revealed type is "builtins.dict[builtins.str, Union[builtins.str, None]]" +d4: dict[str, str | None] = d3 +d5: dict[str, str | None] = d1 | d2 +[builtins fixtures/dict-full.pyi] + +[case testBinaryOperatorContext7] +# https://github.com/python/mypy/issues/5971 +from typing import TypeVar, Generic, Union, Any + +T = TypeVar('T') +T_co = 
TypeVar('T_co', covariant=True) +S = TypeVar('S') + +class A(Generic[T_co]): pass + +class B(A[T]): + def __add__(self, x: B[S]) -> B[Union[T, S]]: ... + +b1: B[Any] +b2: B[int] + +a: A[object] = b1 + b2 + +[case testBinaryOperatorContext8] +from typing import Iterable, Iterator, TypeVar, Generic, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def getitem(self, i: int) -> T: ... # ensure invariance of T + def setitem(self, i: int, v: T) -> None: ... # ensure invariance of T + def __iter__(self) -> Iterator[T]: ... + def __add__(self, other: "Vec[S]") -> "Vec[Union[T, S]]": ... + +mix: Vec[Union[int, str]] +strings: Vec[str] +mix = mix + strings +mix = strings + mix +reveal_type(mix + strings) # N: Revealed type is "__main__.Vec[Union[builtins.int, builtins.str]]" +reveal_type(strings + mix) # N: Revealed type is "__main__.Vec[Union[builtins.str, builtins.int]]" +[builtins fixtures/list.pyi] + + +[case testBinaryOperatorContext9] +# https://github.com/python/mypy/issues/3933#issuecomment-2272804302 +from typing import Iterable, Iterator, TypeVar, Generic, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def getitem(self, i: int) -> T: ... # ensure invariance of T + def setitem(self, i: int, v: T) -> None: ... # ensure invariance of T + def __iter__(self) -> Iterator[T]: ... + def __add__(self, other: "Vec[S]") -> "Vec[Union[T, S]]": ... + +def identity_on_iterable(arg: Iterable[T]) -> Iterable[T]: return arg +x: Vec[str] +y: Vec[None] +reveal_type( identity_on_iterable(y + x) ) # N: Revealed type is "typing.Iterable[Union[None, builtins.str]]" +reveal_type( identity_on_iterable(x + y) ) # N: Revealed type is "typing.Iterable[Union[builtins.str, None]]" + + +[case testBinaryOperatorContextIllegalAssignment] +from typing import Any, TypeVar, Union, Generic + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") + +class Vec(Generic[T]): # invariant in T. + def get(self) -> T: ... + def set(self, arg: T) -> None: ... 
+ def __add__(self, other: "Vec[S]") -> "Vec[Union[T, S]]": ... + +def id_(x: Vec[U]) -> Vec[U]: ... + +def test(a: Vec[int], b: Vec[str]) -> None: + y1: Vec[object] = a + b # E: Incompatible types in assignment (expression has type "Vec[Union[int, str]]", variable has type "Vec[object]") + y2: Vec[object] = b + a # E: Incompatible types in assignment (expression has type "Vec[Union[str, int]]", variable has type "Vec[object]") + y3: Vec[object] = id_(a + b) # E: Incompatible types in assignment (expression has type "Vec[Union[int, str]]", variable has type "Vec[object]") + y4: Vec[object] = id_(b + a) # E: Incompatible types in assignment (expression has type "Vec[Union[str, int]]", variable has type "Vec[object]") + y5: Vec[object] = id_(id_(a + b)) # E: Incompatible types in assignment (expression has type "Vec[Union[int, str]]", variable has type "Vec[object]") + y6: Vec[object] = id_(id_(b + a)) # E: Incompatible types in assignment (expression has type "Vec[Union[str, int]]", variable has type "Vec[object]") + + +[case testBinaryOperatorContextWithAny] +from typing import Any, TypeVar, Union, Generic + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") + +class Vec(Generic[T]): # invariant in T. + def get(self) -> T: ... + def set(self, arg: T) -> None: ... + def __add__(self, other: "Vec[S]") -> "Vec[Union[T | S]]": ... + +def id_(x: Vec[U]) -> Vec[U]: ... 
+ +def test_any_to_union(a: Vec[int], c: Vec[Any]) -> None: + # OK: Vec[Any | int] = Vec[str | int] for materialization Any=str + w1: Vec[Union[str, int]] = a + c + w2: Vec[Union[str, int]] = c + a + w3: Vec[Union[str, int]] = id_(a + c) + w4: Vec[Union[str, int]] = id_(c + a) + w5: Vec[Union[str, int]] = id_(id_(a + c)) + w6: Vec[Union[str, int]] = id_(id_(c + a)) + +def test_any_to_object(a: Vec[int], c: Vec[Any]) -> None: + # OK: Vec[Any | int] = Vec[object | int] = Vec[object] for materialization Any=object + w1: Vec[object] = a + c + w2: Vec[object] = c + a + w3: Vec[object] = id_(a + c) + w4: Vec[object] = id_(c + a) + w5: Vec[object] = id_(id_(a + c)) + w6: Vec[object] = id_(id_(c + a)) + +[case testBinaryOperatorContextWithAnySubclass] +from typing import Any, Generic, TypeVar, Union + +T = TypeVar("T") +S = TypeVar("S") +R = TypeVar("R", covariant=True) + +class A(Generic[R]): + # covariant parent class + + def getter(self) -> R: ... + +class B(A[T], Generic[T]): + # invariant child class + + def getter(self) -> T: ... + def setter(self, arg: T) -> None: ... + def __add__(self, x: "B[S]") -> "B[Union[T, S]]": ... + +def test_any(b1: B[Any], b2: B[int], b3: B[Union[Any, int]]) -> None: + a1: A[object] = b1 + b2 + a2: A[object] = b2 + b1 + a3: A[object] = b3 + +def test_object(b1: B[object], b2: B[int], b3: B[Union[Any, int]]) -> None: + a1: A[object] = b1 + b2 + a2: A[object] = b2 + b1 + a3: A[object] = b3 + +def test_any2(b1: B[Any], b2: B[int], b3: B[Union[Any, int]]) -> None: + # B[Any | int] = B[object] for materialization Any=object + # therefore assignable + a1: B[object] = b1 + b2 + a2: B[object] = b2 + b1 + a3: B[object] = b3 + +def test_object2(b1: B[object], b2: B[int], b3: B[Union[Any, int]]) -> None: + a1: B[object] = b1 + b2 + a2: B[object] = b2 + b1 + a3: B[object] = b3 + + +[case testDictExpressionInCallExpression] +# This checks that the Mapping Key TypeVar gets correctly resolved. 
+from typing import Iterable, Mapping, TypeVar, Union + +K = TypeVar("K") + +def f(objs: Union[Iterable[int], Mapping[K, int]]) -> K: ... + +r = f({"a": 1}) +reveal_type(r) # N: Revealed type is "builtins.str" + +[builtins fixtures/dict.pyi] + + +[case testDictExpressionInCallExpression2] +# This checks that outer constraints are correctly propagated to the inner expression. +from typing import Mapping, TypeVar + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") + +def identity1(arg: T1, /) -> T1: ... +def identity2(arg: T2, /) -> T2: ... +def identity3(arg: T3, /) -> T3: ... + +# test with multiple layers of identity functions +x1: Mapping[str, str] = identity1({}) +reveal_type(x1) # N: Revealed type is "typing.Mapping[builtins.str, builtins.str]" +x2: Mapping[str, str] = identity2(identity1({})) +reveal_type(x2) # N: Revealed type is "typing.Mapping[builtins.str, builtins.str]" +x3: Mapping[str, str] = identity3(identity2(identity1({}))) +reveal_type(x3) # N: Revealed type is "typing.Mapping[builtins.str, builtins.str]" + +# dict[str, str] is a subtype of Mapping[str, str], so mypy uses the more specific type. +y1: Mapping[str, str] +y1 = identity1({}) +reveal_type(y1) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]" +y2: Mapping[str, str] +y2 = identity2(identity1({})) +reveal_type(y2) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]" +y3: Mapping[str, str] +y3 = identity3(identity2(identity1({}))) +reveal_type(y3) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]" +[builtins fixtures/dict.pyi] + + +[case testCallExpressionSecondArgumentSolvesTypeVar] +# This checks a case where the first argument in the inner expression does not solve the TypeVar, +# but the second or third argument does. +from typing import Iterable, TypeVar + +V = TypeVar("V") + +def f(x0: Iterable[V], x1: Iterable[V], x2: Iterable[V]) -> list[V]: ... 
+ +reveal_type( f([], [], []) ) # N: Revealed type is "builtins.list[Never]" +reveal_type( f([], [], [1]) ) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type( f([], [1], []) ) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type( f([1], [], []) ) # N: Revealed type is "builtins.list[builtins.int]" +[builtins fixtures/list.pyi] + +[case testTupleExpressionInNestedCallExpressionCovariant] +from typing import TypeVar +from collections.abc import Sequence + +T = TypeVar("T") +def id_(x: Sequence[T]) -> tuple[T]: ... + +x0: tuple[object] = (1,) +x1: tuple[object] = id_((1,)) +x2: tuple[object] = id_(id_((1,))) +x3: tuple[object] = id_(id_(id_((1,)))) +x4: tuple[object] = id_(id_(id_(id_((1,))))) +[builtins fixtures/tuple.pyi] + +[case testTupleExpressionInNestedCallExpressionInvariant] +from typing import TypeVar +from collections.abc import Set as MutableSequence + +T = TypeVar("T") +def id_(x: MutableSequence[T]) -> tuple[T]: ... + +x0: tuple[object] = (1,) +x1: tuple[object] = id_((1,)) +x2: tuple[object] = id_(id_((1,))) +x3: tuple[object] = id_(id_(id_((1,)))) +x4: tuple[object] = id_(id_(id_(id_((1,))))) +[builtins fixtures/tuple.pyi] + +[case testListExpressionInNestedCallExpressionCovariant] +from typing import TypeVar +from collections.abc import Sequence + +T = TypeVar("T") +def id_(x: Sequence[T]) -> list[T]: ... + +x0: list[object] = [1] +x1: list[object] = id_([1]) +x2: list[object] = id_(id_([1])) +x3: list[object] = id_(id_(id_([1]))) +x4: list[object] = id_(id_(id_(id_([1])))) +[builtins fixtures/list.pyi] + +[case testListExpressionInNestedCallExpressionInvariant] +from typing import TypeVar +from collections.abc import Set as MutableSequence + +T = TypeVar("T") +def id_(x: MutableSequence[T]) -> list[T]: ... 
+ +x0: list[object] = [1] +x1: list[object] = id_([1]) +x2: list[object] = id_(id_([1])) +x3: list[object] = id_(id_(id_([1]))) +x4: list[object] = id_(id_(id_(id_([1])))) +[builtins fixtures/list.pyi] + +[case testSetExpressionInNestedCallExpressionCovariant] +from typing import TypeVar +from collections.abc import Set as AbstractSet + +T = TypeVar("T") +def id_(x: AbstractSet[T]) -> set[T]: ... + +x0: set[object] = {1} +x1: set[object] = id_({1}) +x2: set[object] = id_(id_({1})) +x3: set[object] = id_(id_(id_({1}))) +x4: set[object] = id_(id_(id_(id_({1})))) +[builtins fixtures/set.pyi] + +[case testSetExpressionInNestedCallExpressionInvariant] +from typing import TypeVar +from collections.abc import MutableSet + +T = TypeVar("T") +def id_(x: MutableSet[T]) -> set[T]: ... + +x0: set[object] = {1} +x1: set[object] = id_({1}) +x2: set[object] = id_(id_({1})) +x3: set[object] = id_(id_(id_({1}))) +x4: set[object] = id_(id_(id_(id_({1})))) +[builtins fixtures/set.pyi] + +[case testDictExpressionInNestedCallExpressionCovariant] +from typing import TypeVar +from collections.abc import Mapping + +K = TypeVar("K") +V = TypeVar("V") +def id_(x: Mapping[K, V]) -> dict[K, V]: ... + +x0: dict[object, object] = {"a": 1} +x1: dict[object, object] = id_({"a": 1}) +x2: dict[object, object] = id_(id_({"a": 1})) +x3: dict[object, object] = id_(id_(id_({"a": 1}))) +x4: dict[object, object] = id_(id_(id_(id_({"a": 1})))) +[builtins fixtures/dict.pyi] + +[case testDictExpressionInNestedCallExpressionInvariant] +from typing import TypeVar +from collections.abc import MutableMapping + +K = TypeVar("K") +V = TypeVar("V") +def id_(x: MutableMapping[K, V]) -> dict[K, V]: ... 
+ +x0: dict[object, object] = {"a": 1} +x1: dict[object, object] = id_({"a": 1}) +x2: dict[object, object] = id_(id_({"a": 1})) +x3: dict[object, object] = id_(id_(id_({"a": 1}))) +x4: dict[object, object] = id_(id_(id_(id_({"a": 1})))) +[builtins fixtures/dict.pyi] + +[case testLambdaExpressionInNestedCallExpression] +from typing import TypeVar, Callable + +T = TypeVar("T") +S = TypeVar("S") +def id_(x: Callable[[T], S]) -> Callable[[T], S]: ... + +x0: Callable[[object], object] = lambda x: x +x1: Callable[[object], object] = id_(lambda x: x) +x2: Callable[[object], object] = id_(id_(lambda x: x)) +x3: Callable[[object], object] = id_(id_(id_(lambda x: x))) +x4: Callable[[object], object] = id_(id_(id_(id_(lambda x: x)))) diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 7fa34a398ea05..daa4db062c15f 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`4, y: S`5) -> Union[T`4, S`5]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`6, y: S`7) -> Union[T`6, S`7]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 3b535ab4a1c04..6184fd154c088 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2985,7 +2985,7 @@ def lift(f: F[T]) -> F[Optional[T]]: ... 
def g(x: T) -> T: return x -reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +reveal_type(lift(g)) # N: Revealed type is "__main__.F[Union[T`-1, None]]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrder] diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 8e05f922be174..d9d78715b396a 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -2577,6 +2577,13 @@ C(1)[0] [builtins fixtures/list.pyi] [out] +[case testSerializeRecursiveAlias] +from typing import Callable, Union + +Node = Union[str, int, Callable[[], "Node"]] +n: Node +[out] + [case testSerializeRecursiveAliases1] from typing import Type, Callable, Union diff --git a/test-data/unit/check-inference-context.test b/test-data/unit/check-inference-context.test index a41ee5f59670e..e83d8241bf6cc 100644 --- a/test-data/unit/check-inference-context.test +++ b/test-data/unit/check-inference-context.test @@ -1371,6 +1371,22 @@ x: Tuple[str, ...] = f(tuple) [builtins fixtures/tuple.pyi] [out] +[case testTypedDictWideContext] +from typing_extensions import TypedDict +from typing import TypeVar, Generic + +T = TypeVar('T') + +class A: ... +class B(A): ... 
+ +class OverridesItem(TypedDict, Generic[T]): + tp: type[T] + +d1: dict[str, dict[str, type[A]]] = {"foo": {"bar": B}} +d2: dict[str, OverridesItem[A]] = {"foo": OverridesItem(tp=B)} +[builtins fixtures/dict.pyi] + [case testUseCovariantGenericOuterContextUserDefined] from typing import TypeVar, Callable, Generic diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 63278d6c4547a..58860f471787e 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2427,20 +2427,20 @@ a2.foo2() [case testUnificationEmptyListLeft] def f(): pass a = [] if f() else [0] -a() # E: "list[int]" not callable +reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testUnificationEmptyListRight] def f(): pass a = [0] if f() else [] -a() # E: "list[int]" not callable +reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testUnificationEmptyListLeftInContext] from typing import List def f(): pass a = [] if f() else [0] # type: list[int] -a() # E: "list[int]" not callable +reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testUnificationEmptyListRightInContext] @@ -2448,37 +2448,37 @@ a() # E: "list[int]" not callable from typing import List def f(): pass a = [0] if f() else [] # type: list[int] -a() # E: "list[int]" not callable +reveal_type(a) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] [case testUnificationEmptySetLeft] def f(): pass a = set() if f() else {0} -a() # E: "set[int]" not callable +reveal_type(a) # N: Revealed type is "builtins.set[builtins.int]" [builtins fixtures/set.pyi] [case testUnificationEmptyDictLeft] def f(): pass a = {} if f() else {0: 0} -a() # E: "dict[int, int]" not callable +reveal_type(a) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]" [builtins fixtures/dict.pyi] [case 
testUnificationEmptyDictRight] def f(): pass a = {0: 0} if f() else {} -a() # E: "dict[int, int]" not callable +reveal_type(a) # N: Revealed type is "builtins.dict[builtins.int, builtins.int]" [builtins fixtures/dict.pyi] [case testUnificationDictWithEmptyListLeft] def f(): pass a = {0: []} if f() else {0: [0]} -a() # E: "dict[int, list[int]]" not callable +reveal_type(a) # N: Revealed type is "builtins.dict[builtins.int, builtins.list[builtins.int]]" [builtins fixtures/dict.pyi] [case testUnificationDictWithEmptyListRight] def f(): pass a = {0: [0]} if f() else {0: []} -a() # E: "dict[int, list[int]]" not callable +reveal_type(a) # N: Revealed type is "builtins.dict[builtins.int, builtins.list[builtins.int]]" [builtins fixtures/dict.pyi] [case testMisguidedSetItem] @@ -3388,7 +3388,7 @@ from typing import Any, Union, Iterable y: Union[Iterable[Any], Any] x: Union[Iterable[Any], Any] x = [y] -reveal_type(x) # N: Revealed type is "builtins.list[Any]" +reveal_type(x) # N: Revealed type is "builtins.list[Union[typing.Iterable[Any], Any]]" [builtins fixtures/list.pyi] [case testInferredTypeIsSimpleNestedListLoop] @@ -3410,7 +3410,7 @@ def test(seq: List[Union[Iterable, Any]]) -> None: for k in seq: if bool(): k = [k] - reveal_type(k) # N: Revealed type is "builtins.list[Any]" + reveal_type(k) # N: Revealed type is "builtins.list[Union[typing.Iterable[Any], Any]]" [builtins fixtures/list.pyi] [case testErasedTypeRuntimeCoverage] diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbba..512b82a7db49f 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2946,6 +2946,37 @@ reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word' reveal_type(C().word) # N: Revealed type is "Literal['word']" [builtins fixtures/tuple.pyi] +[case testLiteralInDictExpression] +from typing import Mapping, Literal + +x: Mapping[str, Literal["sum", "mean", "max", "min"]] = {"x": "sum"} + 
+[builtins fixtures/dict.pyi] + +[case testLiteralTernaryExpression] +def test(b: bool) -> None: + l = 1 if b else "a" + reveal_type(l) # N: Revealed type is "Union[builtins.int, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testLiteralTernaryListExpression] +def test(b: bool) -> None: + l = [1] if b else ["a"] + reveal_type(l) # N: Revealed type is "Union[builtins.list[builtins.int], builtins.list[builtins.str]]" +[builtins fixtures/list.pyi] + +[case testLiteralTernarySetExpression] +def test(b: bool) -> None: + s = {1} if b else {"a"} + reveal_type(s) # N: Revealed type is "Union[builtins.set[builtins.int], builtins.set[builtins.str]]" +[builtins fixtures/set.pyi] + +[case testLiteralTernaryDictExpression] +def test(b: bool) -> None: + d = {1:1} if "" else {"a": "a"} + reveal_type(d) # N: Revealed type is "Union[builtins.dict[builtins.int, builtins.int], builtins.dict[builtins.str, builtins.str]]" +[builtins fixtures/dict.pyi] + [case testLiteralTernaryUnionNarrowing] from typing import Literal, Optional diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 679906b0e00ed..87f870e1e4359 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -427,7 +427,8 @@ reveal_type(l) # N: Revealed type is "builtins.list[typing.Generator[builtins.s [builtins fixtures/list.pyi] [case testNoneListTernary] -x = [None] if "" else [1] # E: List item 0 has incompatible type "int"; expected "None" +x = [None] if "" else [1] +reveal_type(x) # N: Revealed type is "Union[builtins.list[None], builtins.list[builtins.int]]" [builtins fixtures/list.pyi] [case testListIncompatibleErrorMessage] diff --git a/test-data/unit/check-python313.test b/test-data/unit/check-python313.test index b46ae0fecfc42..57970c84d69b7 100644 --- a/test-data/unit/check-python313.test +++ b/test-data/unit/check-python313.test @@ -290,3 +290,20 @@ reveal_type(A1().x) # N: Revealed type is "builtins.int" reveal_type(A2().x) # N: 
Revealed type is "builtins.int" reveal_type(A3().x) # N: Revealed type is "builtins.int" [builtins fixtures/tuple.pyi] + +[case testTernaryOperatorWithDefault] +# https://github.com/python/mypy/issues/18817 + +class Ok[T, E = None]: + def __init__(self, value: T) -> None: + self._value = value + +class Err[E, T = None]: + def __init__(self, value: E) -> None: + self._value = value + +type Result[T, E] = Ok[T, E] | Err[E, T] + +class Bar[U]: + def foo(data: U, cond: bool) -> Result[U, str]: + return Ok(data) if cond else Err("Error") diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 86e9f02b52636..706202b1a109e 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -285,21 +285,33 @@ if isinstance(b[0], Sequence): [case testRecursiveAliasWithRecursiveInstance] from typing import Sequence, Union, TypeVar -class A: ... T = TypeVar("T") Nested = Sequence[Union[T, Nested[T]]] +def join(a: T, b: T) -> T: ... + +class A: ... class B(Sequence[B]): ... a: Nested[A] aa: Nested[A] b: B + a = b # OK +reveal_type(a) # N: Revealed type is "__main__.B" + a = [[b]] # OK +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[__main__.B]]" + b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") +reveal_type(b) # N: Revealed type is "__main__.B" + +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" + +def test(a: Nested[A], b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" + reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -def join(a: T, b: T) -> T: ... 
-reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasWithRecursiveInstanceInference] @@ -605,7 +617,7 @@ class NT(NamedTuple, Generic[T]): class A: ... class B(A): ... -nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "A"; expected "Union[int, NT[int]]" +nti: NT[int] = NT(key=0, value=NT(key=1, value=A())) # E: Argument "value" to "NT" has incompatible type "NT[A]"; expected "Union[int, NT[int]]" reveal_type(nti) # N: Revealed type is "tuple[builtins.int, Union[builtins.int, ...], fallback=__main__.NT[builtins.int]]" nta: NT[A] diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 9ab68b32472d1..6adb535f2d599 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -2361,3 +2361,25 @@ describe(CWrong()) # E: Argument 1 to "describe" has incompatible type "CWrong" # N: "CWrong.__call__" has type "Callable[[Arg(int, 'x')], None]" describe(f) [builtins fixtures/isinstancelist.pyi] + + +[case testAssignmentOuterContext] +# https://github.com/python/mypy/issues/16310 +from typing import TypeVar, Generic, Union + +T = TypeVar('T') +U = TypeVar('U') + +class Test(Generic[T, U]): pass + +def test(gen: Test[T, U]) -> tuple[T, Union[U, str]]: ... + +def call() -> Test[str, int]: ... 
+ +def caller_bad() -> None: + x, y = test(call()) + +def caller_ok() -> None: + res = test(call()) + x, y = res +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index 680021a166f2b..def03f5f3ec11 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -629,9 +629,9 @@ from typing import TypeVar T = TypeVar('T') def f(*args: T) -> T: ... -reveal_type(f(*(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" +reveal_type(f(*(1, None))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[builtins.int, None]" [builtins fixtures/tuple.pyi] diff --git a/test-data/unit/fixtures/async_await.pyi b/test-data/unit/fixtures/async_await.pyi index 96ade881111b3..2f81d85af8501 100644 --- a/test-data/unit/fixtures/async_await.pyi +++ b/test-data/unit/fixtures/async_await.pyi @@ -14,6 +14,7 @@ class function: pass class int: pass class float: pass class str: pass +class bytes: pass class bool(int): pass class dict(typing.Generic[T, U]): pass class set(typing.Generic[T]): pass diff --git a/test-data/unit/fixtures/filter.pyi b/test-data/unit/fixtures/filter.pyi new file mode 100644 index 0000000000000..ccbbde2f5a938 --- /dev/null +++ b/test-data/unit/fixtures/filter.pyi @@ -0,0 +1,31 @@ +# Minimal set of builtins required to work with Enums + +from typing import Generic, Iterable, Self, Callable, TypeVar, Any, overload +from typing_extensions import TypeIs, TypeGuard + +_T = TypeVar('_T') +_S = TypeVar('_S') + +class filter(Generic[_T]): + @overload + def __new__(cls, function: None, iterable: Iterable[_T | None], /) -> Self: ... 
+ @overload + def __new__(cls, function: Callable[[_S], TypeGuard[_T]], iterable: Iterable[_S], /) -> Self: ... + @overload + def __new__(cls, function: Callable[[_S], TypeIs[_T]], iterable: Iterable[_S], /) -> Self: ... + @overload + def __new__(cls, function: Callable[[_T], Any], iterable: Iterable[_T], /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> _T: ... + +class function: pass +class type: pass +class int: pass +class tuple: pass +class bool(int): pass +class float: pass +class str: pass +class bool: pass +class ellipsis: pass +class dict: pass +class bytes: pass diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 3dcdf18b2faa3..7ec16a0d20490 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -1,8 +1,9 @@ # Builtins stub used in list-related test cases. -from typing import TypeVar, Generic, Iterable, Iterator, Sequence, overload +from typing import TypeVar, Generic, Iterable, Iterator, Sequence, Union, overload T = TypeVar('T') +_S = TypeVar('_S') class object: def __init__(self) -> None: pass @@ -19,7 +20,7 @@ class list(Sequence[T]): def __iter__(self) -> Iterator[T]: pass def __len__(self) -> int: pass def __contains__(self, item: object) -> bool: pass - def __add__(self, x: list[T]) -> list[T]: pass + def __add__(self, x: list[_S]) -> list[Union[T, _S]]: pass def __mul__(self, x: int) -> list[T]: pass def __getitem__(self, x: int) -> T: pass def __setitem__(self, x: int, v: T) -> None: pass diff --git a/test-data/unit/fixtures/set.pyi b/test-data/unit/fixtures/set.pyi index f757679a95f4e..3968fe8a0375e 100644 --- a/test-data/unit/fixtures/set.pyi +++ b/test-data/unit/fixtures/set.pyi @@ -10,6 +10,7 @@ class object: class type: pass class tuple(Generic[T]): pass +class list(Generic[T]): pass class function: pass class int: pass diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 93b67bfa813a8..710c823b7b7d3 100644 --- 
a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -622,12 +622,14 @@ _program.py:4: error: Argument 1 to "map" has incompatible type "Callable[[VarAr import typing x = range(3) a = list(map(str, x)) +reveal_type(a) a + 1 [out] -_testMapStr.py:4: error: No overload variant of "__add__" of "list" matches argument type "int" -_testMapStr.py:4: note: Possible overload variants: -_testMapStr.py:4: note: def __add__(self, list[str], /) -> list[str] -_testMapStr.py:4: note: def [_S] __add__(self, list[_S], /) -> list[Union[_S, str]] +_testMapStr.py:4: note: Revealed type is "builtins.list[builtins.str]" +_testMapStr.py:5: error: No overload variant of "__add__" of "list" matches argument type "int" +_testMapStr.py:5: note: Possible overload variants: +_testMapStr.py:5: note: def __add__(self, list[str], /) -> list[str] +_testMapStr.py:5: note: def [_S] __add__(self, list[_S], /) -> list[Union[_S, str]] [case testRelativeImport] import typing