Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions crates/ty_python_semantic/resources/corpus/cyclic_lambdas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# This test would previously panic with: `infer_definition_types(Id(1406)): execute: too many cycle iterations`.

lambda: name_4

@lambda: name_5
class name_1: ...

name_2 = [lambda: name_4, name_1]

if name_2:
@(*name_2,)
class name_3: ...
assert unique_name_19

@lambda: name_3
class name_4[*name_2](0, name_1=name_3): ...

try:
[name_5, name_4] = *name_4, = name_4
except* 0:
...
else:
async def name_4(): ...

for name_3 in name_4: ...
53 changes: 53 additions & 0 deletions crates/ty_python_semantic/resources/mdtest/bidirectional.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,59 @@ def _(flag: bool):
reveal_type(x2) # revealed: list[int | None]
```

## Lambda expressions

If a lambda expression is annotated as a `Callable` type, the body of the lambda is inferred with
the annotated return type as type context, and the annotated parameter types are respected:

```py
from typing import Callable, TypedDict

class Bar(TypedDict):
bar: int

f1 = lambda x: {"bar": 1}
reveal_type(f1) # revealed: (x) -> dict[str, int]

f2: Callable[[int], Bar] = lambda x: {"bar": 1}
reveal_type(f2) # revealed: (x: int) -> Bar

# error: [missing-typed-dict-key] "Missing required key 'bar' in TypedDict `Bar` constructor"
# error: [invalid-assignment] "Object of type `(x: int) -> dict[Unknown, Unknown]` is not assignable to `(int, /) -> Bar`"
f3: Callable[[int], Bar] = lambda x: {}
reveal_type(f3) # revealed: (int, /) -> Bar

# TODO: This should reveal `str`.
f4: Callable[[str], str] = lambda x: reveal_type(x) # revealed: Unknown
reveal_type(f4) # revealed: (x: str) -> Unknown

# TODO: This should not error once we support `Unpack`.
# error: [invalid-assignment]
f5: Callable[[*tuple[int, ...]], None] = lambda x, y, z: None
reveal_type(f5) # revealed: (tuple[int, ...], /) -> None

f6: Callable[[int, str], None] = lambda *args: None
reveal_type(f6) # revealed: (*args) -> None

# N.B. `Callable` annotations only support positional parameters.
# error: [invalid-assignment]
f7: Callable[[int], None] = lambda *, x=1: None
reveal_type(f7) # revealed: (int, /) -> None

# TODO: This should reveal `(*args: int, *, x=1) -> None` once we support `Unpack`.
f8: Callable[[*tuple[int, ...], int], None] = lambda *args, x=1: None
reveal_type(f8) # revealed: (*args, *, x=1) -> None
```

We do not currently account for type annotations present later in the scope:

```py
f9 = lambda: [1]
# TODO: This should not error.
_: list[int | str] = f9() # error: [invalid-assignment]
reveal_type(f9) # revealed: () -> list[int]
```

## Dunder Calls

The key and value parameters types are used as type context for `__setitem__` dunder calls:
Expand Down
8 changes: 4 additions & 4 deletions crates/ty_python_semantic/resources/mdtest/cycle.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,16 @@ class C:
self.c = lambda positional_only=self.c, /: positional_only
self.d = lambda *, kw_only=self.d: kw_only

# revealed: (positional=...) -> Unknown
# revealed: (positional=...) -> Unknown | ((positional=...) -> Divergent) | ((positional=...) -> Unknown) | ((positional=...) -> Divergent)
reveal_type(self.a)

# revealed: (*, kw_only=...) -> Unknown
# revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent)
reveal_type(self.b)

# revealed: (positional_only=..., /) -> Unknown
# revealed: (positional_only=..., /) -> Unknown | ((positional_only=..., /) -> Divergent) | ((positional_only=..., /) -> Unknown) | ((positional_only=..., /) -> Divergent)
reveal_type(self.c)

# revealed: (*, kw_only=...) -> Unknown
# revealed: (*, kw_only=...) -> Unknown | ((*, kw_only=...) -> Divergent) | ((*, kw_only=...) -> Unknown) | ((*, kw_only=...) -> Divergent)
reveal_type(self.d)
```

Expand Down
17 changes: 10 additions & 7 deletions crates/ty_python_semantic/resources/mdtest/expression/lambda.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
`lambda` expressions can be defined without any parameters.

```py
reveal_type(lambda: 1) # revealed: () -> Unknown
reveal_type(lambda: 1) # revealed: () -> Literal[1]

# error: [unresolved-reference]
reveal_type(lambda: a) # revealed: () -> Unknown
Expand All @@ -24,7 +24,7 @@ reveal_type(lambda a, b: a + b) # revealed: (a, b) -> Unknown
But, it can have default values:

```py
reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown
reveal_type(lambda a=1: a) # revealed: (a=1) -> Unknown | Literal[1]
reveal_type(lambda a, b=2: a) # revealed: (a, b=2) -> Unknown
```

Expand All @@ -37,25 +37,25 @@ reveal_type(lambda a, b, /, c: c) # revealed: (a, b, /, c) -> Unknown
And, keyword-only parameters:

```py
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown
reveal_type(lambda a, *, b=2, c: b) # revealed: (a, *, b=2, c) -> Unknown | Literal[2]
```

And, variadic parameter:

```py
reveal_type(lambda *args: args) # revealed: (*args) -> Unknown
reveal_type(lambda *args: args) # revealed: (*args) -> tuple[Unknown, ...]
```

And, keyword-varidic parameter:

```py
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> Unknown
reveal_type(lambda **kwargs: kwargs) # revealed: (**kwargs) -> dict[str, Unknown]
```

Mixing all of them together:

```py
# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> Unknown
# revealed: (a, b, /, c=True, *args, *, d="default", e=5, **kwargs) -> None
reveal_type(lambda a, b, /, c=True, *args, d="default", e=5, **kwargs: None)
```

Expand Down Expand Up @@ -94,7 +94,7 @@ Here, a `lambda` expression is used as the default value for a parameter in anot
expression.

```py
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Unknown
reveal_type(lambda a=lambda x, y: 0: 2) # revealed: (a=...) -> Literal[2]
```

## Assignment
Expand All @@ -114,6 +114,9 @@ a4: Callable[[int, int], None] = lambda *args: None
a5: Callable[[], None] = lambda x: None
# error: [invalid-assignment]
a6: Callable[[int], None] = lambda: None

# error: [invalid-assignment]
a7: Callable[[], str] = lambda: 1
```

## Function-like behavior of lambdas
Expand Down
3 changes: 2 additions & 1 deletion crates/ty_python_semantic/src/semantic_index/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ impl<'db> ScopeId<'db> {
pub(crate) fn accepts_type_context(self, db: &dyn Db) -> bool {
matches!(
self.node(db),
NodeWithScopeKind::ListComprehension(_)
NodeWithScopeKind::Lambda(_)
| NodeWithScopeKind::ListComprehension(_)
| NodeWithScopeKind::SetComprehension(_)
| NodeWithScopeKind::DictComprehension(_)
)
Expand Down
112 changes: 90 additions & 22 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
NodeWithScopeKind::Function(function) => {
self.infer_function_body(function.node(self.module()));
}
NodeWithScopeKind::Lambda(lambda) => self.infer_lambda_body(lambda.node(self.module())),
NodeWithScopeKind::Lambda(lambda) => {
self.infer_lambda_body(lambda.node(self.module()), tcx);
}
NodeWithScopeKind::Class(class) => self.infer_class_body(class.node(self.module())),
NodeWithScopeKind::ClassTypeParameters(class) => {
self.infer_class_type_params(class.node(self.module()));
Expand Down Expand Up @@ -5468,7 +5470,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Subscript(subscript) => self.infer_subscript_expression(subscript),
ast::Expr::Slice(slice) => self.infer_slice_expression(slice),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression, tcx),
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
ast::Expr::Lambda(lambda_expression) => {
self.infer_lambda_expression(lambda_expression, tcx)
}
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
ast::Expr::Starred(starred) => self.infer_starred_expression(starred, tcx),
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
Expand Down Expand Up @@ -6724,11 +6728,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}

fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {
self.infer_expression(&lambda_expression.body, TypeContext::default());
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda, tcx: TypeContext<'db>) {
self.infer_expression(&lambda_expression.body, tcx);
}

fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> {
fn infer_lambda_expression(
&mut self,
lambda_expression: &ast::ExprLambda,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprLambda {
range: _,
node_index: _,
Expand All @@ -6740,27 +6748,64 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let in_stub = self.in_stub();
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into());

let callable_tcx = if let Some(tcx) = tcx.annotation
// TODO: We could perform multi-inference here if there are multiple `Callable` annotations
// in the union.
&& let Some(callable) = tcx
.filter_union(self.db(), Type::is_callable_type)
.as_callable()
{
let [signature] = callable.signatures(self.db()).overloads.as_slice() else {
panic!("`Callable` type annotations cannot be overloaded");
};

Some(signature)
} else {
None
};

// Extract the annotated parameter types.
//
// Note that `Callable` annotations are only valid for positional parameters.
let mut parameter_types = match callable_tcx {
None => [].iter(),
Some(signature) => signature.parameters().into_iter(),
}
.map(Parameter::annotated_type);

let parameters = if let Some(parameters) = parameters {
let positional_only = parameters
.posonlyargs
.iter()
.map(|param| {
Parameter::positional_only(Some(param.name().id.clone()))
let parameter = Parameter::positional_only(Some(param.name().id.clone()))
.with_optional_default_type(param.default().map(|default_expr| {
self.infer_expression(default_expr, TypeContext::default())
.replace_parameter_defaults(self.db())
}))
}));

if let Some(annotated_type) = parameter_types.next() {
parameter.with_annotated_type(annotated_type)
} else {
parameter
}
})
.collect::<Vec<_>>();
let positional_or_keyword = parameters
.args
.iter()
.map(|param| {
Parameter::positional_or_keyword(param.name().id.clone())
let parameter = Parameter::positional_or_keyword(param.name().id.clone())
.with_optional_default_type(param.default().map(|default_expr| {
self.infer_expression(default_expr, TypeContext::default())
.replace_parameter_defaults(self.db())
}))
}));

if let Some(annotated_type) = parameter_types.next() {
parameter.with_annotated_type(annotated_type)
} else {
parameter
}
})
.collect::<Vec<_>>();
let variadic = parameters
Expand All @@ -6784,25 +6829,48 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.as_ref()
.map(|param| Parameter::keyword_variadic(param.name().id.clone()));

Parameters::new(
self.db(),
positional_only
.into_iter()
.chain(positional_or_keyword)
.chain(variadic)
.chain(keyword_only)
.chain(keyword_variadic),
)
let parameters = positional_only
.into_iter()
.chain(positional_or_keyword)
.chain(variadic)
.chain(keyword_only)
.chain(keyword_variadic);

Parameters::new(self.db(), parameters)
} else {
Parameters::empty()
};

self.deferred_state = previous_deferred_state;

// TODO: Useful inference of a lambda's return type will require a different approach,
// which does the inference of the body expression based on arguments at each call site,
// rather than eagerly computing a return type without knowing the argument types.
Type::function_like_callable(self.db(), Signature::new(parameters, Type::unknown()))
let Some(scope_id) = self
.index
.try_node_scope(NodeWithScopeRef::Lambda(lambda_expression))
else {
return Type::unknown();
};

let scope = scope_id.to_scope_id(self.db(), self.file());

// If we have a direct `Callable` type context, we can infer the body with the annotated
// return type as type context.
let return_tcx = if let Some(signature) = callable_tcx {
match signature.return_ty {
Type::Dynamic(DynamicType::Unknown) => TypeContext::new(None),
_ => TypeContext::new(Some(signature.return_ty)),
}
} else {
// TODO: Useful inference of a lambda's return type will require a different approach,
// which does the inference of the body expression based on arguments at each call site,
// rather than eagerly computing a return type without knowing the argument types.
TypeContext::new(None)
};

let inference = infer_scope_types(self.db(), scope, return_tcx);
self.extend_scope(inference);

let return_ty = inference.expression_type(lambda_expression.body.as_ref());
Type::function_like_callable(self.db(), Signature::new(parameters, return_ty))
}

/// Attempt to narrow a splatted dictionary argument based on the narrowed types of individual
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {

ast::Expr::Lambda(lambda_expression) => {
if !self.deferred_state.in_string_annotation() {
self.infer_lambda_expression(lambda_expression);
self.infer_lambda_expression(lambda_expression, TypeContext::default());
}
self.report_invalid_type_expression(
expression,
Expand Down
Loading