From 63e2872abbdfe66f41dfadfaada7beb9f146d451 Mon Sep 17 00:00:00 2001 From: Hugo Polloli Date: Tue, 23 Dec 2025 10:58:22 +0100 Subject: [PATCH] [ty] Add support for dict literals and dict() calls as default values for parameters with TypedDict types --- .../resources/mdtest/function/parameters.md | 31 +++++ crates/ty_python_semantic/src/types/infer.rs | 3 + .../src/types/infer/builder.rs | 126 ++++++++++++++---- 3 files changed, 132 insertions(+), 28 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/function/parameters.md b/crates/ty_python_semantic/resources/mdtest/function/parameters.md index 8611354f582e9..06ee9b4e4586f 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/parameters.md +++ b/crates/ty_python_semantic/resources/mdtest/function/parameters.md @@ -64,6 +64,37 @@ def g(x: Any = "foo"): reveal_type(x) # revealed: Any ``` +## TypedDict defaults use annotation context + +```py +from typing import TypedDict + +class Foo(TypedDict): + x: int + +def x(a: Foo = {"x": 42}): ... +def y(a: Foo = dict(x=42)): ... +``` + +## TypedDict defaults still validate keys and value types + +```py +from typing import TypedDict + +class Foo(TypedDict): + x: int + y: int + +# error: [missing-typed-dict-key] +def missing_key(a: Foo = {"x": 42}): ... + +# error: [invalid-argument-type] +def wrong_type(a: Foo = {"x": "s", "y": 1}): ... + +# error: [invalid-key] +def extra_key(a: Foo = {"x": 1, "y": 2, "z": 3}): ... +``` + ## Stub functions ```toml diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 32c68ab1e39c3..8e14313fbcdae 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -713,6 +713,9 @@ struct DefinitionInferenceExtra<'db> { /// String annotations found in this region string_annotations: FxHashSet, + /// Functions called while inferring this definition. + called_functions: Box<[FunctionType<'db>]>, + /// The fallback type for missing expressions/bindings/declarations or recursive type inference. cycle_recovery: Option>, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 6abd37aa127a3..a825d3112093b 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -398,6 +398,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if let Some(extra) = &inference.extra { + self.called_functions + .extend(extra.called_functions.iter().copied()); self.extend_cycle_recovery(extra.cycle_recovery); self.context.extend(&extra.diagnostics); self.deferred @@ -2711,23 +2713,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { decorator_types_and_nodes.push((decorator_type, decorator)); } - // In stub files, default values may reference names that are defined later in the file. - let in_stub = self.in_stub(); - let previous_deferred_state = std::mem::replace(&mut self.deferred_state, in_stub.into()); - for default in parameters + let has_defaults = parameters .iter_non_variadic_params() - .filter_map(|param| param.default.as_deref()) - { - self.infer_expression(default, TypeContext::default()); - } - self.deferred_state = previous_deferred_state; + .any(|param| param.default.is_some()); // If there are type params, parameters and returns are evaluated in that scope. Otherwise, // we always defer the inference of the parameters and returns. That ensures that we do not // add any spurious salsa cycles when applying decorators below. (Applying a decorator // requires getting the signature of this function definition, which in turn requires - // (lazily) inferring the parameter and return types.) - if type_params.is_none() { + // (lazily) inferring the parameter and return types.) If defaults exist, we also defer so + // they can be inferred once with type context in the enclosing scope. + if type_params.is_none() || has_defaults { self.deferred.insert(definition, self.multi_inference_state); } @@ -2918,12 +2914,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// default value) both belong to outer scopes. (The default value always belongs to the outer /// scope in which the function is defined, the annotation belongs either to the outer scope, /// or maybe to an intervening type-params scope, if it's a generic function.) So we don't use - /// `self.infer_expression` or store any expression types here, we just use `expression_ty` to - /// get the types of the expressions from their respective scopes. + /// `self.infer_expression` or store any expression types here, we just query for the types of + /// the expressions from their respective scopes. /// - /// It is safe (non-cycle-causing) to use `expression_ty` here, because an outer scope can't - /// depend on a definition from an inner scope, so we shouldn't be in-process of inferring the - /// outer scope here. + /// It is safe (non-cycle-causing) to query the annotation type via `file_expression_type` + /// here, because an outer scope can't depend on a definition from an inner scope, so we + /// shouldn't be in-process of inferring the outer scope here. fn infer_parameter_definition( &mut self, parameter_with_default: &'ast ast::ParameterWithDefault, @@ -2935,13 +2931,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { range: _, node_index: _, } = parameter_with_default; - let default_ty = default - .as_ref() - .map(|default| self.file_expression_type(default)); + let default_expr = default.as_ref(); if let Some(annotation) = parameter.annotation.as_ref() { let declared_ty = self.file_expression_type(annotation); - if let Some(default_ty) = default_ty { + if let Some(default_expr) = default_expr { + let default_expr = default_expr.as_ref(); + let default_ty = self.file_expression_type(default_expr); + + // Avoid duplicate diagnostics: invalid TypedDict literals already emit specific errors. + let suppress_invalid_default = diagnostic::is_invalid_typed_dict_literal( + self.db(), + declared_ty, + default_expr.into(), + ); if !default_ty.is_assignable_to(self.db(), declared_ty) + && !suppress_invalid_default && !((self.in_stub() || self.in_function_overload_or_abstractmethod() || self.scope().scope(self.db()).in_type_checking_block() @@ -2971,7 +2975,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &DeclaredAndInferredType::are_the_same_type(declared_ty), ); } else { - let ty = if let Some(default_ty) = default_ty { + let ty = if let Some(default_expr) = default_expr { + let default_ty = self.file_expression_type(default_expr); UnionType::from_elements(self.db(), [Type::unknown(), default_ty]) } else if let Some(ty) = self.special_first_method_parameter_type(parameter) { ty @@ -3389,12 +3394,72 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } self.context.set_in_no_type_check(prev_in_no_type_check); + let has_type_params = function.type_params.is_some(); + let has_defaults = function + .parameters + .iter_non_variadic_params() + .any(|param| param.default.is_some()); + let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); - self.infer_return_type_annotation( - function.returns.as_deref(), - self.defer_annotations().into(), - ); - self.infer_parameters(function.parameters.as_ref()); + + if !has_type_params { + self.infer_return_type_annotation( + function.returns.as_deref(), + self.defer_annotations().into(), + ); + self.infer_parameters(function.parameters.as_ref()); + } + + if has_defaults { + // In stub files, default values may reference names that are defined later in the file. + let in_stub = self.in_stub(); + let previous_deferred_state = + std::mem::replace(&mut self.deferred_state, in_stub.into()); + + // For generic functions, only defaults are inferred here; annotation types come from + // the type-params scope. + if has_type_params { + let type_params_scope = self + .index + .node_scope(NodeWithScopeRef::FunctionTypeParameters(function)) + .to_scope_id(self.db(), self.file()); + let type_params_inference = + infer_scope_types(self.db(), type_params_scope, TypeContext::default()); + + for param_with_default in function.parameters.iter_non_variadic_params() { + let Some(default) = param_with_default.default.as_deref() else { + continue; + }; + let tcx = param_with_default + .parameter + .annotation + .as_deref() + .map(|annotation| { + TypeContext::new(Some( + type_params_inference.expression_type(annotation), + )) + }) + .unwrap_or_else(TypeContext::default); + self.infer_expression(default, tcx); + } + } else { + for param_with_default in function.parameters.iter_non_variadic_params() { + let Some(default) = param_with_default.default.as_deref() else { + continue; + }; + let tcx = param_with_default + .parameter + .annotation + .as_deref() + .map(|annotation| TypeContext::new(Some(self.expression_type(annotation)))) + .unwrap_or_else(TypeContext::default); + self.infer_expression(default, tcx); + } + } + + self.deferred_state = previous_deferred_state; + } + self.typevar_binding_context = previous_typevar_binding_context; } @@ -15299,6 +15364,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deferred, cycle_recovery, undecorated_type, + called_functions, // builder only state dataclass_field_specifiers: _, all_definitely_bound: _, @@ -15306,7 +15372,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { deferred_state: _, multi_inference_state: _, inner_expression_inference_state: _, - called_functions: _, index: _, region: _, return_types_and_ranges: _, @@ -15319,10 +15384,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { || !string_annotations.is_empty() || cycle_recovery.is_some() || undecorated_type.is_some() - || !deferred.is_empty()) + || !deferred.is_empty() + || !called_functions.is_empty()) .then(|| { Box::new(DefinitionInferenceExtra { string_annotations, + called_functions: called_functions + .into_iter() + .collect::>() + .into_boxed_slice(), cycle_recovery, deferred: deferred.into_boxed_slice(), diagnostics,