diff --git a/crates/ty_python_semantic/resources/corpus/cyclic_lambdas.py b/crates/ty_python_semantic/resources/corpus/cyclic_lambdas.py new file mode 100644 index 0000000000000..e7868d8dda2cf --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cyclic_lambdas.py @@ -0,0 +1,25 @@ +# This test would previously panic with: `infer_definition_types(Id(1406)): execute: too many cycle iterations`. + +lambda: name_4 + +@lambda: name_5 +class name_1: ... + +name_2 = [lambda: name_4, name_1] + +if name_2: + @(*name_2,) + class name_3: ... + assert unique_name_19 + +@lambda: name_3 +class name_4[*name_2](0, name_1=name_3): ... + +try: + [name_5, name_4] = *name_4, = name_4 +except* 0: + ... +else: + async def name_4(): ... + +for name_3 in name_4: ... diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index d62af32105ef0..db8fe891d7305 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -397,6 +397,59 @@ def _(flag: bool): reveal_type(x2) # revealed: list[int | None] ``` +## Lambda expressions + +If a lambda expression is annotated as a `Callable` type, the body of the lambda is inferred with +the annotated return type as type context, and the annotated parameter types are respected: + +```py +from typing import Callable, TypedDict + +class Bar(TypedDict): + bar: int + +f1 = lambda x: {"bar": 1} +reveal_type(f1) # revealed: (x) -> dict[str, int] + +f2: Callable[[int], Bar] = lambda x: {"bar": 1} +reveal_type(f2) # revealed: (x: int) -> Bar + +# error: [missing-typed-dict-key] "Missing required key 'bar' in TypedDict `Bar` constructor" +# error: [invalid-assignment] "Object of type `(x: int) -> dict[Unknown, Unknown]` is not assignable to `(int, /) -> Bar`" +f3: Callable[[int], Bar] = lambda x: {} +reveal_type(f3) # revealed: (int, /) -> Bar + +# TODO: This should reveal `str`. +f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: Unknown +reveal_type(f4) # revealed: (x: str) -> Unknown + +# TODO: This should not error once we support `Unpack`. +# error: [invalid-assignment] +f5: Callable[[*tuple[int, ...]], None] = lambda x, y, z: None +reveal_type(f5) # revealed: (tuple[int, ...], /) -> None + +f6: Callable[[int, str], None] = lambda *args: None +reveal_type(f6) # revealed: (*args) -> None + +# N.B. `Callable` annotations only support positional parameters. +# error: [invalid-assignment] +f7: Callable[[int], None] = lambda *, x=1: None +reveal_type(f7) # revealed: (int, /) -> None + +# TODO: This should reveal `(*args: int, *, x=1) -> None` once we support `Unpack`. +f8: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None +reveal_type(f8) # revealed: (*args, *, x=1) -> None +``` + +We do not currently account for type annotations present later in the scope: + +```py +f9 = lambda: [1] +# TODO: This should not error. +_: list[int | str] = f9() # error: [invalid-assignment] +reveal_type(f9) # revealed: () -> list[int] +``` + ## Dunder Calls The key and value parameters types are used as type context for `__setitem__` dunder calls: diff --git a/crates/ty_python_semantic/resources/mdtest/cycle.md b/crates/ty_python_semantic/resources/mdtest/cycle.md index 4a7fc249c0345..c29d61747a283 100644 --- a/crates/ty_python_semantic/resources/mdtest/cycle.md +++ b/crates/ty_python_semantic/resources/mdtest/cycle.md @@ -128,16 +128,16 @@ class C: self.c = lambda positional_only=self.c, /: positional_only self.d = lambda *, kw_only=self.d: kw_only - # revealed: (positional=...) -> Unknown + # revealed: (positional=...) -> Unknown | ((positional=...) -> Divergent) | ((positional=...) -> Unknown) | ((positional=...) -> Divergent) reveal_type(self.a) - # revealed: (*, kw_only=...) -> Unknown + # revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent) reveal_type(self.b) - # revealed: (positional_only=..., /) -> Unknown + # revealed: (positional_only=..., /) -> Unknown | ((positional_only=..., /) -> Divergent) | ((positional_only=..., /) -> Unknown) | ((positional_only=..., /) -> Divergent) reveal_type(self.c) - # revealed: (*, kw_only=...) -> Unknown + # revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent) reveal_type(self.d) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/expression/lambda.md b/crates/ty_python_semantic/resources/mdtest/expression/lambda.md index ad0a0afc5595b..0993896dc3d8e 100644 --- a/crates/ty_python_semantic/resources/mdtest/expression/lambda.md +++ b/crates/ty_python_semantic/resources/mdtest/expression/lambda.md @@ -5,7 +5,7 @@ `lambda` expressions can be defined without any parameters. ```py -reveal_type(lambda: 1) # revealed: () -> Unknown +reveal_type(lambda: 1) # revealed: () -> Literal[1] # error: [unresolved-reference] reveal_type(lambda: a) # revealed: () -> Unknown @@ -24,7 +24,7 @@ reveal_type(lambda a, b: a + b) # revealed: (a, b) -> Unknown But, it can have default values: ```py -reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown +reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown | Literal[1] reveal_type(lambda a, b=2: a) # revealed: (a, b=2) -> Unknown ``` @@ -37,25 +37,25 @@ reveal_type(lambda a, b, /, c: c) # revealed: (a, b, /, c) -> Unknown And, keyword-only parameters: ```py -reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown +reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown | Literal[2] ``` And, variadic parameter: ```py -reveal_type(lambda *args: args) # revealed: (*args) -> Unknown +reveal_type(lambda *args: args) # revealed: (*args) -> tuple[Unknown, ...] ``` And, keyword-varidic parameter: ```py -reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> Unknown +reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> dict[str, Unknown] ``` Mixing all of them together: ```py -# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> Unknown +# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> None reveal_type(lambda a, b, /, c=True, *args, d="default", e=5, **kwargs: None) ``` @@ -94,7 +94,7 @@ Here, a `lambda` expression is used as the default value for a parameter in anot expression. ```py -reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Unknown +reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Literal[2] ``` ## Assignment @@ -114,6 +114,9 @@ a4: Callable[[int, int], None] = lambda *args: None a5: Callable[[], None] = lambda x: None # error: [invalid-assignment] a6: Callable[[int], None] = lambda: None + +# error: [invalid-assignment] +a7: Callable[[], str] = lambda: 1 ``` ## Function-like behavior of lambdas diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index df10f6e7dd660..d5ffb43f0a0e8 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -38,7 +38,8 @@ impl<'db> ScopeId<'db> { pub(crate) fn accepts_type_context(self, db: &dyn Db) -> bool { matches!( self.node(db), - NodeWithScopeKind::ListComprehension(_) + NodeWithScopeKind::Lambda(_) + | NodeWithScopeKind::ListComprehension(_) | NodeWithScopeKind::SetComprehension(_) | NodeWithScopeKind::DictComprehension(_) ) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 520b2e8f7841e..a23e181e8fd8c 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -579,7 +579,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { NodeWithScopeKind::Function(function) => { self.infer_function_body(function.node(self.module())); } - NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node(self.module())), + NodeWithScopeKind::Lambda(lambda) => { + self.infer_lambda_body(lambda.node(self.module()), tcx); + } NodeWithScopeKind::Class(class) => self.infer_class_body(class.node(self.module())), NodeWithScopeKind::ClassTypeParameters(class) => { self.infer_class_type_params(class.node(self.module())); @@ -5468,7 +5470,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript), ast::Expr::Slice(slice) => self.infer_slice_expression(slice), ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx), - ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression), + ast::Expr::Lambda(lambda_expression) => { + self.infer_lambda_expression(lambda_expression, tcx) + } ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx), ast::Expr::Starred(starred) => self.infer_starred_expression(starred, tcx), ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), @@ -6724,11 +6728,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) { - self.infer_expression(&lambda_expression.body, TypeContext::default()); + fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda, tcx: TypeContext<'db>) { + self.infer_expression(&lambda_expression.body, tcx); } - fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> { + fn infer_lambda_expression( + &mut self, + lambda_expression: &ast::ExprLambda, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprLambda { range: _, node_index: _, @@ -6740,27 +6748,64 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let in_stub = self.in_stub(); let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into()); + let callable_tcx = if let Some(tcx) = tcx.annotation + // TODO: We could perform multi-inference here if there are multiple `Callable` annotations + // in the union. + && let Some(callable) = tcx + .filter_union(self.db(), Type::is_callable_type) + .as_callable() + { + let [signature] = callable.signatures(self.db()).overloads.as_slice() else { + panic!("`Callable` type annotations cannot be overloaded"); + }; + + Some(signature) + } else { + None + }; + + // Extract the annotated parameter types. + // + // Note that `Callable` annotations are only valid for positional parameters. + let mut parameter_types = match callable_tcx { + None => [].iter(), + Some(signature) => signature.parameters().into_iter(), + } + .map(Parameter::annotated_type); + let parameters = if let Some(parameters) = parameters { let positional_only = parameters .posonlyargs .iter() .map(|param| { - Parameter::positional_only(Some(param.name().id.clone())) + let parameter = Parameter::positional_only(Some(param.name().id.clone())) .with_optional_default_type(param.default().map(|default_expr| { self.infer_expression(default_expr, TypeContext::default()) .replace_parameter_defaults(self.db()) - })) + })); + + if let Some(annotated_type) = parameter_types.next() { + parameter.with_annotated_type(annotated_type) + } else { + parameter + } }) .collect::>(); let positional_or_keyword = parameters .args .iter() .map(|param| { - Parameter::positional_or_keyword(param.name().id.clone()) + let parameter = Parameter::positional_or_keyword(param.name().id.clone()) .with_optional_default_type(param.default().map(|default_expr| { self.infer_expression(default_expr, TypeContext::default()) .replace_parameter_defaults(self.db()) - })) + })); + + if let Some(annotated_type) = parameter_types.next() { + parameter.with_annotated_type(annotated_type) + } else { + parameter + } }) .collect::>(); let variadic = parameters @@ -6784,25 +6829,48 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .as_ref() .map(|param| Parameter::keyword_variadic(param.name().id.clone())); - Parameters::new( - self.db(), - positional_only - .into_iter() - .chain(positional_or_keyword) - .chain(variadic) - .chain(keyword_only) - .chain(keyword_variadic), - ) + let parameters = positional_only + .into_iter() + .chain(positional_or_keyword) + .chain(variadic) + .chain(keyword_only) + .chain(keyword_variadic); + + Parameters::new(self.db(), parameters) } else { Parameters::empty() }; self.deferred_state = previous_deferred_state; - // TODO: Useful inference of a lambda's return type will require a different approach, - // which does the inference of the body expression based on arguments at each call site, - // rather than eagerly computing a return type without knowing the argument types. - Type::function_like_callable(self.db(), Signature::new(parameters, Type::unknown())) + let Some(scope_id) = self + .index + .try_node_scope(NodeWithScopeRef::Lambda(lambda_expression)) + else { + return Type::unknown(); + }; + + let scope = scope_id.to_scope_id(self.db(), self.file()); + + // If we have a direct `Callable` type context, we can infer the body with the annotated + // return type as type context. + let return_tcx = if let Some(signature) = callable_tcx { + match signature.return_ty { + Type::Dynamic(DynamicType::Unknown) => TypeContext::new(None), + _ => TypeContext::new(Some(signature.return_ty)), + } + } else { + // TODO: Useful inference of a lambda's return type will require a different approach, + // which does the inference of the body expression based on arguments at each call site, + // rather than eagerly computing a return type without knowing the argument types. + TypeContext::new(None) + }; + + let inference = infer_scope_types(self.db(), scope, return_tcx); + self.extend_scope(inference); + + let return_ty = inference.expression_type(lambda_expression.body.as_ref()); + Type::function_like_callable(self.db(), Signature::new(parameters, return_ty)) } /// Attempt to narrow a splatted dictionary argument based on the narrowed types of individual diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 92a8f5b4ce495..7691b685b5ca8 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -461,7 +461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { ast::Expr::Lambda(lambda_expression) => { if !self.deferred_state.in_string_annotation() { - self.infer_lambda_expression(lambda_expression); + self.infer_lambda_expression(lambda_expression, TypeContext::default()); } self.report_invalid_type_expression( expression,