Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,37 @@ def g(x: Any = "foo"):
reveal_type(x) # revealed: Any
```

## TypedDict defaults use annotation context

```py
from typing import TypedDict

class Foo(TypedDict):
x: int

def x(a: Foo = {"x": 42}): ...
def y(a: Foo = dict(x=42)): ...
```

## TypedDict defaults still validate keys and value types

```py
from typing import TypedDict

class Foo(TypedDict):
x: int
y: int

# error: [missing-typed-dict-key]
def missing_key(a: Foo = {"x": 42}): ...

# error: [invalid-argument-type]
def wrong_type(a: Foo = {"x": "s", "y": 1}): ...

# error: [invalid-key]
def extra_key(a: Foo = {"x": 1, "y": 2, "z": 3}): ...
```

## Stub functions

```toml
Expand Down
3 changes: 3 additions & 0 deletions crates/ty_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,9 @@ struct DefinitionInferenceExtra<'db> {
/// String annotations found in this region
string_annotations: FxHashSet<ExpressionNodeKey>,

/// Functions called while inferring this definition.
called_functions: Box<[FunctionType<'db>]>,

/// The fallback type for missing expressions/bindings/declarations or recursive type inference.
cycle_recovery: Option<Type<'db>>,

Expand Down
126 changes: 98 additions & 28 deletions crates/ty_python_semantic/src/types/infer/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}

if let Some(extra) = &inference.extra {
self.called_functions
.extend(extra.called_functions.iter().copied());
self.extend_cycle_recovery(extra.cycle_recovery);
self.context.extend(&extra.diagnostics);
self.deferred
Expand Down Expand Up @@ -2711,23 +2713,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
decorator_types_and_nodes.push((decorator_type, decorator));
}

// In stub files, default values may reference names that are defined later in the file.
let in_stub = self.in_stub();
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into());
for default in parameters
let has_defaults = parameters
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.infer_expression(default, TypeContext::default());
}
self.deferred_state = previous_deferred_state;
.any(|param| param.default.is_some());

// If there are type params, parameters and returns are evaluated in that scope. Otherwise,
// we always defer the inference of the parameters and returns. That ensures that we do not
// add any spurious salsa cycles when applying decorators below. (Applying a decorator
// requires getting the signature of this function definition, which in turn requires
// (lazily) inferring the parameter and return types.)
if type_params.is_none() {
// (lazily) inferring the parameter and return types.) If defaults exist, we also defer so
// they can be inferred once with type context in the enclosing scope.
if type_params.is_none() || has_defaults {
self.deferred.insert(definition, self.multi_inference_state);
}

Expand Down Expand Up @@ -2918,12 +2914,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
/// default value) both belong to outer scopes. (The default value always belongs to the outer
/// scope in which the function is defined, the annotation belongs either to the outer scope,
/// or maybe to an intervening type-params scope, if it's a generic function.) So we don't use
/// `self.infer_expression` or store any expression types here, we just use `expression_ty` to
/// get the types of the expressions from their respective scopes.
/// `self.infer_expression` or store any expression types here, we just query for the types of
/// the expressions from their respective scopes.
///
/// It is safe (non-cycle-causing) to use `expression_ty` here, because an outer scope can't
/// depend on a definition from an inner scope, so we shouldn't be in-process of inferring the
/// outer scope here.
/// It is safe (non-cycle-causing) to query the annotation type via `file_expression_type`
/// here, because an outer scope can't depend on a definition from an inner scope, so we
/// shouldn't be in-process of inferring the outer scope here.
fn infer_parameter_definition(
&mut self,
parameter_with_default: &'ast ast::ParameterWithDefault,
Expand All @@ -2935,13 +2931,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
range: _,
node_index: _,
} = parameter_with_default;
let default_ty = default
.as_ref()
.map(|default| self.file_expression_type(default));
let default_expr = default.as_ref();
if let Some(annotation) = parameter.annotation.as_ref() {
let declared_ty = self.file_expression_type(annotation);
Copy link
Member

@ibraheemdev ibraheemdev Jan 21, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We infer the parameter default as part of the outer scope here. Inferring them multiple times means that the original inferred type will still be displayed, e.g., by the IDE on hover. We should try to infer the value directly with type context in its scope.

I think what this requires is creating a deferred function scope if the function has default value expressions, even if it has type parameters, and then infer default values in infer_function_deferred with the parameter annotations. Note that if there are type parameters, the only thing that should be inferred in infer_function_deferred is the default values, and you should call infer_scope_types on the type-params scope to get the parameter annotation types to use as type context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi thank you so much, I completely missed that, I've started working on this and have set this PR as draft for now

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, things are much better now
I hope what I did is what you had in mind! Defaults are now inferred once in the deferred pass with the parameter annotation as type context (for generic functions, we pull annotations from the type‑params scope).

An added benefit of this method is that we replace generic invalid-parameter-default errors by more explicit TypedDict specific errors, I updated the mdtest to reflect that :)

if let Some(default_ty) = default_ty {
if let Some(default_expr) = default_expr {
let default_expr = default_expr.as_ref();
let default_ty = self.file_expression_type(default_expr);

// Avoid duplicate diagnostics: invalid TypedDict literals already emit specific errors.
let suppress_invalid_default = diagnostic::is_invalid_typed_dict_literal(
self.db(),
declared_ty,
default_expr.into(),
);
if !default_ty.is_assignable_to(self.db(), declared_ty)
&& !suppress_invalid_default
&& !((self.in_stub()
|| self.in_function_overload_or_abstractmethod()
|| self.scope().scope(self.db()).in_type_checking_block()
Expand Down Expand Up @@ -2971,7 +2975,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&DeclaredAndInferredType::are_the_same_type(declared_ty),
);
} else {
let ty = if let Some(default_ty) = default_ty {
let ty = if let Some(default_expr) = default_expr {
let default_ty = self.file_expression_type(default_expr);
UnionType::from_elements(self.db(), [Type::unknown(), default_ty])
} else if let Some(ty) = self.special_first_method_parameter_type(parameter) {
ty
Expand Down Expand Up @@ -3389,12 +3394,72 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
self.context.set_in_no_type_check(prev_in_no_type_check);

let has_type_params = function.type_params.is_some();
let has_defaults = function
.parameters
.iter_non_variadic_params()
.any(|param| param.default.is_some());

let previous_typevar_binding_context = self.typevar_binding_context.replace(definition);
self.infer_return_type_annotation(
function.returns.as_deref(),
self.defer_annotations().into(),
);
self.infer_parameters(function.parameters.as_ref());

if !has_type_params {
self.infer_return_type_annotation(
function.returns.as_deref(),
self.defer_annotations().into(),
);
self.infer_parameters(function.parameters.as_ref());
}

if has_defaults {
// In stub files, default values may reference names that are defined later in the file.
let in_stub = self.in_stub();
let previous_deferred_state =
std::mem::replace(&mut self.deferred_state, in_stub.into());

// For generic functions, only defaults are inferred here; annotation types come from
// the type-params scope.
if has_type_params {
let type_params_scope = self
.index
.node_scope(NodeWithScopeRef::FunctionTypeParameters(function))
.to_scope_id(self.db(), self.file());
let type_params_inference =
infer_scope_types(self.db(), type_params_scope, TypeContext::default());

for param_with_default in function.parameters.iter_non_variadic_params() {
let Some(default) = param_with_default.default.as_deref() else {
continue;
};
let tcx = param_with_default
.parameter
.annotation
.as_deref()
.map(|annotation| {
TypeContext::new(Some(
type_params_inference.expression_type(annotation),
))
})
.unwrap_or_else(TypeContext::default);
self.infer_expression(default, tcx);
}
} else {
for param_with_default in function.parameters.iter_non_variadic_params() {
let Some(default) = param_with_default.default.as_deref() else {
continue;
};
let tcx = param_with_default
.parameter
.annotation
.as_deref()
.map(|annotation| TypeContext::new(Some(self.expression_type(annotation))))
.unwrap_or_else(TypeContext::default);
self.infer_expression(default, tcx);
}
}

self.deferred_state = previous_deferred_state;
}

self.typevar_binding_context = previous_typevar_binding_context;
}

Expand Down Expand Up @@ -15299,14 +15364,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred,
cycle_recovery,
undecorated_type,
called_functions,
// builder only state
dataclass_field_specifiers: _,
all_definitely_bound: _,
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
inner_expression_inference_state: _,
called_functions: _,
index: _,
region: _,
return_types_and_ranges: _,
Expand All @@ -15319,10 +15384,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|| !string_annotations.is_empty()
|| cycle_recovery.is_some()
|| undecorated_type.is_some()
|| !deferred.is_empty())
|| !deferred.is_empty()
|| !called_functions.is_empty())
.then(|| {
Box::new(DefinitionInferenceExtra {
string_annotations,
called_functions: called_functions
.into_iter()
.collect::<Vec<_>>()
.into_boxed_slice(),
cycle_recovery,
deferred: deferred.into_boxed_slice(),
diagnostics,
Expand Down