From 5844c0103d30de4f8890414b38ce1a8d92d6f2d2 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 23 Sep 2025 14:29:17 +0900 Subject: [PATCH 01/26] [ty] propagate the annotated return type of functions to the inference of expressions in return statements --- .../resources/mdtest/bidirectional.md | 78 +++++++++++++++++++ .../ty_python_semantic/src/types/function.rs | 72 +++++++++++++++-- crates/ty_python_semantic/src/types/infer.rs | 24 ++++++ .../src/types/infer/builder.rs | 17 +++- .../src/types/signatures.rs | 69 +++++++++++----- 5 files changed, 232 insertions(+), 28 deletions(-) create mode 100644 crates/ty_python_semantic/resources/mdtest/bidirectional.md diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md new file mode 100644 index 00000000000000..717c01f8c7c26c --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -0,0 +1,78 @@ +# Bidirectional Type Inference + +ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an +expression "from the outside in". Normally, type inference proceeds "from the inside out". That is, +in order to infer the type of an expression, the types of all sub-expressions must first be +inferred. There is no reverse dependency. However, when performing complex type inference, such as +when generics are involved, the type of an outer expression can sometimes be useful in inferring +inner expressions. Bidirectional type inference is a mechanism that propagates such "expected types" +to the inference of inner expressions. + +## Propagating target type annotation + +```toml +[environment] +python-version = "3.12" +``` + +```py +def list1[T](x: T) -> list[T]: + return [x] + +l1 = list1(1) +reveal_type(l1) # revealed: list[Literal[1]] +l2: list[int] = list1(1) +reveal_type(l2) # revealed: list[int] + +# `list[Literal[1]]` and `list[int]` are incompatible, since `list[T]` is invariant in `T`. +# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`" +l2 = l1 + +intermediate = list1(1) +# TODO: the error will not occur if we can infer the type of `intermediate` to be `list[int]` +# error: [invalid-assignment] "Object of type `list[Literal[1]]` is not assignable to `list[int]`" +l3: list[int] = intermediate +# TODO: it would be nice if this were `list[int]` +reveal_type(intermediate) # revealed: list[Literal[1]] +reveal_type(l3) # revealed: list[int] +``` + +```py +from typing import TypedDict + +class TD(TypedDict): + x: int + +d1 = {"x": 1} +d2: TD = {"x": 1} +d3: dict[str, int] = {"x": 1} + +reveal_type(d1) # revealed: dict[@Todo(dict literal key type), @Todo(dict literal value type)] +reveal_type(d2) # revealed: TD +# TODO: should be `dict[str, int]` +reveal_type(d3) # revealed: dict[@Todo(dict literal key type), @Todo(dict literal value type)] +``` + +## Propagating return type annotation + +```toml +[environment] +python-version = "3.12" +``` + +```py +def list1[T](x: T) -> list[T]: + return [x] + +def get_data() -> dict | None: + return {} + +def wrap_data() -> list[dict]: + if not (res := get_data()): + return list1({}) + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + # `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible, + # but the return type check passes here because the inferred return type is widened + # by bidirectional type inference. + return list1(res) +``` diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index c7b94f0584fccf..8df2256c7a0b25 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -74,7 +74,7 @@ use crate::types::diagnostic::{ }; use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::narrow::ClassInfoConstraintFunction; -use crate::types::signatures::{CallableSignature, Signature}; +use crate::types::signatures::{CallableSignature, Signature, SignatureFlags}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, @@ -342,6 +342,32 @@ impl<'db> OverloadLiteral<'db> { self, db: &'db dyn Db, inherited_generic_context: Option>, + ) -> Signature<'db> { + self.signature_impl(db, inherited_generic_context, true) + } + + /// Typed internally-visible "raw" signature for this function. + /// That is, type variables in parameter types and the return type remain non-inferable. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the function. This means that if the + /// calling query is not in the same file as this function is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. + pub(crate) fn raw_signature( + self, + db: &'db dyn Db, + inherited_generic_context: Option>, + ) -> Signature<'db> { + self.signature_impl(db, inherited_generic_context, false) + } + + fn signature_impl( + self, + db: &'db dyn Db, + inherited_generic_context: Option>, + mark_typevars_inferable: bool, ) -> Signature<'db> { /// `self` or `cls` can be implicitly positional-only if: /// - It is a method AND @@ -408,23 +434,30 @@ impl<'db> OverloadLiteral<'db> { GenericContext::from_type_params(db, index, definition, type_params) }); let file_scope_id = scope.file_scope_id(db); - let is_generator = file_scope_id.is_generator_function(index); - let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( + + let mut flags = SignatureFlags::empty(); + if file_scope_id.is_generator_function(index) { + flags |= SignatureFlags::IS_GENERATOR; + } + if has_implicitly_positional_only_first_param( db, self, function_stmt_node, file_scope_id, index, - ); - + ) { + flags |= SignatureFlags::HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER; + } + if mark_typevars_inferable { + flags |= SignatureFlags::MARK_TYPEVARS_INFERABLE; + } Signature::from_function( db, generic_context, inherited_generic_context, definition, function_stmt_node, - is_generator, - has_implicitly_positional_first_parameter, + flags, ) } @@ -676,6 +709,20 @@ impl<'db> FunctionLiteral<'db> { ) } + /// Typed externally-visible "raw" signature of the last overload or implementation of this function. + /// + /// ## Warning + /// + /// This uses the semantic index to find the definition of the function. This means that if the + /// calling query is not in the same file as this function is defined in, then this will create + /// a cross-module dependency directly on the full AST which will lead to cache + /// over-invalidation. + fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> { + let inherited_generic_context = self.inherited_generic_context(db); + self.last_definition(db) + .raw_signature(db, inherited_generic_context) + } + fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { let context = self .inherited_generic_context(db) @@ -930,6 +977,17 @@ impl<'db> FunctionType<'db> { .last_definition_signature(db, self.type_mappings(db)) } + /// Typed externally-visible "raw" signature of the last overload or implementation of this function. + #[salsa::tracked( + returns(ref), + cycle_fn=last_definition_signature_cycle_recover, + cycle_initial=last_definition_signature_cycle_initial, + heap_size=ruff_memory_usage::heap_size, + )] + pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> { + self.literal(db).last_definition_raw_signature(db) + } + /// Convert the `FunctionType` into a [`CallableType`]. pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> { CallableType::new(db, self.signature(db), false) diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 647136b8dbac36..60bd2ba133e2c4 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::types::diagnostic::TypeCheckDiagnostics; +use crate::types::function::FunctionType; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers}; @@ -499,6 +500,29 @@ pub(crate) fn nearest_enclosing_class<'db>( }) } +/// Returns the type of the nearest enclosing function for the given scope. +/// +/// This function walks up the ancestor scopes starting from the given scope, +/// and finds the closest (non-lambda) function definition. +/// +/// Returns `None` if no enclosing function is found. +pub(crate) fn nearest_enclosing_function<'db>( + db: &'db dyn Db, + semantic: &SemanticIndex<'db>, + scope: ScopeId, +) -> Option> { + semantic + .ancestor_scopes(scope.file_scope_id(db)) + .find_map(|(_, ancestor_scope)| { + let func = ancestor_scope.node().as_function()?; + let definition = semantic.expect_single_definition(func); + infer_definition_types(db, definition) + .declaration_type(definition) + .inner_type() + .into_function_literal() + }) +} + /// A region within which we can infer types. #[derive(Copy, Clone, Debug)] pub(crate) enum InferenceRegion<'db> { diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index db04778cf8398e..a8f820da9db59e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -75,6 +75,7 @@ use crate::types::function::{ }; use crate::types::generics::{GenericContext, bind_typevar}; use crate::types::generics::{LegacyGenericBase, SpecializationBuilder}; +use crate::types::infer::nearest_enclosing_function; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; use crate::types::signatures::Signature; @@ -4770,9 +4771,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - if let Some(ty) = - self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()) - { + let annotated_return_type = |_| { + nearest_enclosing_function(self.db(), self.index, self.scope()).map(|func| { + // When inferring expressions within a function body, + // the expected type passed should be the "raw" type, i.e. type variables in the return type are non-inferable. + TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) + }) + }; + let tcx = ret + .value + .as_ref() + .and_then(annotated_return_type) + .unwrap_or_default(); + if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) { let range = ret .value .as_ref() diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 35da85a09f2bd0..63ae0aa39d8adb 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -298,6 +298,29 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( } } +bitflags::bitflags! { + #[derive(Default, Debug, Copy, Clone)] + pub(crate) struct SignatureFlags: u8 { + const IS_GENERATOR = 1 << 0; + const HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER = 1 << 1; + const MARK_TYPEVARS_INFERABLE = 1 << 2; + } +} + +impl SignatureFlags { + pub(crate) fn is_generator(self) -> bool { + self.contains(SignatureFlags::IS_GENERATOR) + } + + pub(crate) fn has_implicitly_positional_first_parameter(self) -> bool { + self.contains(SignatureFlags::HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER) + } + + pub(crate) fn mark_typevars_inferable(self) -> bool { + self.contains(SignatureFlags::MARK_TYPEVARS_INFERABLE) + } +} + impl<'db> Signature<'db> { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option>) -> Self { Self { @@ -354,22 +377,21 @@ impl<'db> Signature<'db> { inherited_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, - is_generator: bool, - has_implicitly_positional_first_parameter: bool, + flags: SignatureFlags, ) -> Self { - let parameters = Parameters::from_parameters( - db, - definition, - function_node.parameters.as_ref(), - has_implicitly_positional_first_parameter, - ); + let parameters = + Parameters::from_parameters(db, definition, function_node.parameters.as_ref(), flags); let return_ty = function_node.returns.as_ref().map(|returns| { - let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()) - .apply_type_mapping( + let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()); + let plain_return_ty = if flags.mark_typevars_inferable() { + plain_return_ty.apply_type_mapping( db, &TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)), - ); - if function_node.is_async && !is_generator { + ) + } else { + plain_return_ty + }; + if function_node.is_async && !flags.is_generator() { KnownClass::CoroutineType .to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty]) } else { @@ -1168,7 +1190,7 @@ impl<'db> Parameters<'db> { db: &'db dyn Db, definition: Definition<'db>, parameters: &ast::Parameters, - has_implicitly_positional_first_parameter: bool, + flags: SignatureFlags, ) -> Self { let ast::Parameters { posonlyargs, @@ -1195,6 +1217,7 @@ impl<'db> Parameters<'db> { name: Some(param.parameter.name.id.clone()), default_type: default_type(param), }, + flags, ) }; @@ -1207,7 +1230,7 @@ impl<'db> Parameters<'db> { if positional_only.is_empty() { let pos_or_keyword_iter = pos_or_keyword_iter.by_ref(); - if has_implicitly_positional_first_parameter { + if flags.has_implicitly_positional_first_parameter() { positional_only.extend(pos_or_keyword_iter.next().map(pos_only_param)); } @@ -1227,6 +1250,7 @@ impl<'db> Parameters<'db> { name: arg.parameter.name.id.clone(), default_type: default_type(arg), }, + flags, ) }); @@ -1238,6 +1262,7 @@ impl<'db> Parameters<'db> { ParameterKind::Variadic { name: arg.name.id.clone(), }, + flags, ) }); @@ -1250,6 +1275,7 @@ impl<'db> Parameters<'db> { name: arg.parameter.name.id.clone(), default_type: default_type(arg), }, + flags, ) }); @@ -1261,6 +1287,7 @@ impl<'db> Parameters<'db> { ParameterKind::KeywordVariadic { name: arg.name.id.clone(), }, + flags, ) }); @@ -1544,13 +1571,19 @@ impl<'db> Parameter<'db> { definition: Definition<'db>, parameter: &ast::Parameter, kind: ParameterKind<'db>, + flags: SignatureFlags, ) -> Self { Self { annotated_type: parameter.annotation().map(|annotation| { - definition_expression_type(db, definition, annotation).apply_type_mapping( - db, - &TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)), - ) + let annotated_ty = definition_expression_type(db, definition, annotation); + if flags.mark_typevars_inferable() { + annotated_ty.apply_type_mapping( + db, + &TypeMapping::MarkTypeVarsInferable(BindingContext::Definition(definition)), + ) + } else { + annotated_ty + } }), kind, form: ParameterForm::Value, From c6f798c4b0b27fff6676ffb90511593c38a20e65 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 23 Sep 2025 16:26:05 +0900 Subject: [PATCH 02/26] don't wrap the raw return types of async functions in `CoroutineType` --- .../resources/mdtest/bidirectional.md | 19 +++++++++++++++++++ .../ty_python_semantic/src/types/function.rs | 12 +++++++----- .../src/types/infer/builder.rs | 3 ++- .../src/types/signatures.rs | 7 ++++++- 4 files changed, 34 insertions(+), 7 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 717c01f8c7c26c..fb3f1f5dd220f4 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -61,6 +61,8 @@ python-version = "3.12" ``` ```py +from typing import overload + def list1[T](x: T) -> list[T]: return [x] @@ -75,4 +77,21 @@ def wrap_data() -> list[dict]: # but the return type check passes here because the inferred return type is widened # by bidirectional type inference. return list1(res) + +@overload +def f(x: int) -> list[int]: ... +@overload +def f(x: str) -> list[str]: ... +def f(x: int | str) -> list[int] | list[str]: + # `list[int] | list[str]` is a different type than `list[int | str]`. + if isinstance(x, int): + return list1(x) + else: + return list1(x) + +reveal_type(f(1)) # revealed: list[int] +reveal_type(f("a")) # revealed: list[str] + +async def f() -> list[int]: + return list1(1) ``` diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 8df2256c7a0b25..1c1640ab79b4df 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -343,11 +343,12 @@ impl<'db> OverloadLiteral<'db> { db: &'db dyn Db, inherited_generic_context: Option>, ) -> Signature<'db> { - self.signature_impl(db, inherited_generic_context, true) + self.signature_impl(db, inherited_generic_context, false) } /// Typed internally-visible "raw" signature for this function. - /// That is, type variables in parameter types and the return type remain non-inferable. + /// That is, type variables in parameter types and the return type remain non-inferable, + /// and the return types of async functions are not wrapped in `CoroutineType[...]`. /// /// ## Warning /// @@ -360,14 +361,14 @@ impl<'db> OverloadLiteral<'db> { db: &'db dyn Db, inherited_generic_context: Option>, ) -> Signature<'db> { - self.signature_impl(db, inherited_generic_context, false) + self.signature_impl(db, inherited_generic_context, true) } fn signature_impl( self, db: &'db dyn Db, inherited_generic_context: Option>, - mark_typevars_inferable: bool, + raw: bool, ) -> Signature<'db> { /// `self` or `cls` can be implicitly positional-only if: /// - It is a method AND @@ -448,8 +449,9 @@ impl<'db> OverloadLiteral<'db> { ) { flags |= SignatureFlags::HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER; } - if mark_typevars_inferable { + if !raw { flags |= SignatureFlags::MARK_TYPEVARS_INFERABLE; + flags |= SignatureFlags::COROUTINE_RETURN_TYPE; } Signature::from_function( db, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index a8f820da9db59e..c44843145e968e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4774,7 +4774,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let annotated_return_type = |_| { nearest_enclosing_function(self.db(), self.index, self.scope()).map(|func| { // When inferring expressions within a function body, - // the expected type passed should be the "raw" type, i.e. type variables in the return type are non-inferable. + // the expected type passed should be the "raw" type, + // i.e. type variables in the return type are non-inferable, and the return types of async functions are not wrapped in `CoroutineType[...]`. TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) }) }; diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 63ae0aa39d8adb..26e7c07b622319 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -304,6 +304,7 @@ bitflags::bitflags! { const IS_GENERATOR = 1 << 0; const HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER = 1 << 1; const MARK_TYPEVARS_INFERABLE = 1 << 2; + const COROUTINE_RETURN_TYPE = 1 << 3; } } @@ -319,6 +320,10 @@ impl SignatureFlags { pub(crate) fn mark_typevars_inferable(self) -> bool { self.contains(SignatureFlags::MARK_TYPEVARS_INFERABLE) } + + pub(crate) fn coroutine_return_type(self) -> bool { + self.contains(SignatureFlags::COROUTINE_RETURN_TYPE) + } } impl<'db> Signature<'db> { @@ -391,7 +396,7 @@ impl<'db> Signature<'db> { } else { plain_return_ty }; - if function_node.is_async && !flags.is_generator() { + if function_node.is_async && !flags.is_generator() && flags.coroutine_return_type() { KnownClass::CoroutineType .to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty]) } else { From d14dcc87791e70c06dd688551a28f391961c16f9 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 23 Sep 2025 23:48:59 +0900 Subject: [PATCH 03/26] improve `SpecializationBuilder::infer` behavior when `formal` is a union type --- .../resources/mdtest/bidirectional.md | 11 ++++++++++- crates/ty_python_semantic/src/types/generics.rs | 10 +++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index fb3f1f5dd220f4..603970004f768e 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -35,6 +35,9 @@ l3: list[int] = intermediate # TODO: it would be nice if this were `list[int]` reveal_type(intermediate) # revealed: list[Literal[1]] reveal_type(l3) # revealed: list[int] + +l4: list[int] | None = list1(1) +reveal_type(l4) # revealed: list[int] ``` ```py @@ -78,6 +81,12 @@ def wrap_data() -> list[dict]: # by bidirectional type inference. return list1(res) +def wrap_data2() -> list[dict] | None: + if not (res := get_data()): + return None + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + return list1(res) + @overload def f(x: int) -> list[int]: ... @overload @@ -92,6 +101,6 @@ def f(x: int | str) -> list[int] | list[str]: reveal_type(f(1)) # revealed: list[int] reveal_type(f("a")) # revealed: list[str] -async def f() -> list[int]: +async def g() -> list[int]: return list1(1) ``` diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 1677a7ea6feb0d..93cbe0dc4432b9 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1038,7 +1038,7 @@ impl<'db> SpecializationBuilder<'db> { pub(crate) fn infer( &mut self, formal: Type<'db>, - actual: Type<'db>, + mut actual: Type<'db>, ) -> Result<(), SpecializationError<'db>> { if formal == actual { return Ok(()); @@ -1067,6 +1067,14 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } + if let Type::Union(union) = actual { + // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`. + // So, here we remove the union elements that are not related to `formal`. + actual = union.filter(self.db, |actual_elem| { + !actual_elem.is_disjoint_from(self.db, formal) + }); + } + match (formal, actual) { (Type::Union(formal), _) => { // TODO: We haven't implemented a full unification solver yet. If typevars appear From d33ee55e0d1e71259ea00af3a4e128bafc21481e Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 24 Sep 2025 13:24:59 +0900 Subject: [PATCH 04/26] fix `nearest_enclosing_function` returning incorrect types for decorated functions --- .../resources/mdtest/bidirectional.md | 14 +++++++++++++- crates/ty_python_semantic/src/types/infer.rs | 7 ++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 603970004f768e..6dddac7d4f4770 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -64,7 +64,7 @@ python-version = "3.12" ``` ```py -from typing import overload +from typing import overload, Callable def list1[T](x: T) -> list[T]: return [x] @@ -87,6 +87,18 @@ def wrap_data2() -> list[dict] | None: reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] return list1(res) +def deco[T](func: Callable[[], T]) -> Callable[[], T]: + return func + +def outer() -> Callable[[], list[dict]]: + @deco + def inner() -> list[dict]: + if not (res := get_data()): + return list1({}) + reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] + return list1(res) + return inner + @overload def f(x: int) -> list[int]: ... @overload diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 60bd2ba133e2c4..94a74232889f99 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -516,9 +516,10 @@ pub(crate) fn nearest_enclosing_function<'db>( .find_map(|(_, ancestor_scope)| { let func = ancestor_scope.node().as_function()?; let definition = semantic.expect_single_definition(func); - infer_definition_types(db, definition) - .declaration_type(definition) - .inner_type() + let inference = infer_definition_types(db, definition); + inference + .undecorated_type() + .unwrap_or_else(|| inference.declaration_type(definition).inner_type()) .into_function_literal() }) } From 962284112e2cb3491f9b2d4ef7f135d5ec1340ec Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 24 Sep 2025 14:37:33 +0900 Subject: [PATCH 05/26] Update bidirectional.md --- .../ty_python_semantic/resources/mdtest/bidirectional.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 6dddac7d4f4770..3bf1141d6bdbed 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -38,6 +38,14 @@ reveal_type(l3) # revealed: list[int] l4: list[int] | None = list1(1) reveal_type(l4) # revealed: list[int] + +def _(l: list[int] | None = None): + l1 = l or list() + reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] + + l2: list[int] = l or list() + # TODO: should be `list[int]` + reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] ``` ```py From ef75f0a9ace043c3bc956b34c411ce696706ef01 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 24 Sep 2025 15:48:29 +0900 Subject: [PATCH 06/26] prevent incorrect specializations in `SpecializationBuilder::infer` --- .../resources/mdtest/bidirectional.md | 15 ++++++++++++++- .../resources/mdtest/generics/legacy/functions.md | 4 ++++ .../resources/mdtest/generics/pep695/functions.md | 4 ++++ crates/ty_python_semantic/src/types/generics.rs | 5 +++++ 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 3bf1141d6bdbed..a79cbbb15b0007 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -1,4 +1,4 @@ -# Bidirectional Type Inference +# Bidirectional type inference ty partially supports bidirectional type inference. This is a mechanism for inferring the type of an expression "from the outside in". Normally, type inference proceeds "from the inside out". That is, @@ -46,6 +46,13 @@ def _(l: list[int] | None = None): l2: list[int] = l or list() # TODO: should be `list[int]` reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] + +def f[T](x: T, cond: bool) -> T | list[T]: + return x if cond else [x] + +# TODO: no error +# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`" +l5: int | list[int] = f(1, True) ``` ```py @@ -123,4 +130,10 @@ reveal_type(f("a")) # revealed: list[str] async def g() -> list[int]: return list1(1) + +def h[T](x: T, cond: bool) -> T | list[T]: + return i(x, cond) + +def i[T](x: T, cond: bool) -> T | list[T]: + return x if cond else [x] ``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 3175ed7216dced..144da51c9affd3 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -323,6 +323,10 @@ def union_param(x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(None)) # revealed: Unknown + +def _(x: int | None): + # TODO: should be `int` + reveal_type(union_param(x)) # revealed: Unknown ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index a9224d46c800a3..6d597636543de1 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -286,6 +286,10 @@ def union_param[T](x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] reveal_type(union_param(None)) # revealed: Unknown + +def _(x: int | None): + # TODO: should be `int` + reveal_type(union_param(x)) # revealed: Unknown ``` ```py diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index e31d760155c5d9..af24814744beea 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1088,6 +1088,11 @@ impl<'db> SpecializationBuilder<'db> { } match (formal, actual) { + (Type::Union(_), Type::Union(_)) => { + // TODO: We need to infer specializations appropriately. + // e.g. + // `formal: list[T] | T | U, actual: V | int | list[V]` => `T = V, U = int` + } (Type::Union(formal), _) => { // TODO: We haven't implemented a full unification solver yet. If typevars appear // in multiple union elements, we ideally want to express that _only one_ of them From 6c5625b75cc2714260174faddd894171eac8f558 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Sat, 27 Sep 2025 01:03:07 +0900 Subject: [PATCH 07/26] Update crates/ty_python_semantic/src/types/infer/builder.rs Co-authored-by: Alex Waygood --- crates/ty_python_semantic/src/types/infer/builder.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 511270487c9a89..cd5f72160eafcf 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4775,7 +4775,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { nearest_enclosing_function(self.db(), self.index, self.scope()).map(|func| { // When inferring expressions within a function body, // the expected type passed should be the "raw" type, - // i.e. type variables in the return type are non-inferable, and the return types of async functions are not wrapped in `CoroutineType[...]`. + // i.e. type variables in the return type are non-inferable, + // and the return types of async functions are not wrapped in `CoroutineType[...]`. TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) }) }; From 15652804652057071be03296094c1ff7a24640c0 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Sat, 27 Sep 2025 01:03:18 +0900 Subject: [PATCH 08/26] Update crates/ty_python_semantic/resources/mdtest/bidirectional.md Co-authored-by: Alex Waygood --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index a79cbbb15b0007..aec7d47935f8ec 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -119,7 +119,7 @@ def f(x: int) -> list[int]: ... @overload def f(x: str) -> list[str]: ... def f(x: int | str) -> list[int] | list[str]: - # `list[int] | list[str]` is a different type than `list[int | str]`. + # `list[int] | list[str]` is disjoint from `list[int | str]`. if isinstance(x, int): return list1(x) else: From 7e8595aaaf7b3923cfe9c8b31d6b403f1aca1a07 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Sat, 27 Sep 2025 01:03:47 +0900 Subject: [PATCH 09/26] Update crates/ty_python_semantic/resources/mdtest/bidirectional.md Co-authored-by: Alex Waygood --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index aec7d47935f8ec..33560eef4fda5a 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -55,6 +55,8 @@ def f[T](x: T, cond: bool) -> T | list[T]: l5: int | list[int] = f(1, True) ``` +`typed_dict.py`: + ```py from typing import TypedDict From d666494f1c91155e0dc16e518b94bf0af1fcf7d7 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Sat, 27 Sep 2025 04:13:47 +0900 Subject: [PATCH 10/26] refactor according to the review --- .../ty_python_semantic/src/types/function.rs | 62 +++++---- .../ty_python_semantic/src/types/generics.rs | 23 ++++ .../src/types/signatures.rs | 130 ++++++------------ 3 files changed, 99 insertions(+), 116 deletions(-) diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 1c1640ab79b4df..7c33551dc4217c 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -74,7 +74,7 @@ use crate::types::diagnostic::{ }; use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::narrow::ClassInfoConstraintFunction; -use crate::types::signatures::{CallableSignature, Signature, SignatureFlags}; +use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, @@ -343,7 +343,33 @@ impl<'db> OverloadLiteral<'db> { db: &'db dyn Db, inherited_generic_context: Option>, ) -> Signature<'db> { - self.signature_impl(db, inherited_generic_context, false) + let mut signature = self.raw_signature(db, inherited_generic_context); + + let scope = self.body_scope(db); + let module = parsed_module(db, self.file(db)).load(db); + let function_node = scope.node(db).expect_function().node(&module); + let index = semantic_index(db, scope.file(db)); + let file_scope_id = scope.file_scope_id(db); + let is_generator = file_scope_id.is_generator_function(index); + + if function_node.is_async && !is_generator { + signature = signature.wrap_coroutine_return_type(db); + } + signature = signature.mark_typevars_inferable(db); + + let pep695_ctx = function_node.type_params.as_ref().map(|type_params| { + GenericContext::from_type_params(db, index, self.definition(db), type_params) + }); + let legacy_ctx = GenericContext::from_function_params( + db, + self.definition(db), + signature.parameters(), + signature.return_ty, + ); + signature.generic_context = + GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx); + + signature } /// Typed internally-visible "raw" signature for this function. @@ -356,19 +382,10 @@ impl<'db> OverloadLiteral<'db> { /// calling query is not in the same file as this function is defined in, then this will create /// a cross-module dependency directly on the full AST which will lead to cache /// over-invalidation. - pub(crate) fn raw_signature( + fn raw_signature( self, db: &'db dyn Db, inherited_generic_context: Option>, - ) -> Signature<'db> { - self.signature_impl(db, inherited_generic_context, true) - } - - fn signature_impl( - self, - db: &'db dyn Db, - inherited_generic_context: Option>, - raw: bool, ) -> Signature<'db> { /// `self` or `cls` can be implicitly positional-only if: /// - It is a method AND @@ -431,35 +448,26 @@ impl<'db> OverloadLiteral<'db> { let function_stmt_node = scope.node(db).expect_function().node(&module); let definition = self.definition(db); let index = semantic_index(db, scope.file(db)); - let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| { + let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| { GenericContext::from_type_params(db, index, definition, type_params) }); let file_scope_id = scope.file_scope_id(db); - let mut flags = SignatureFlags::empty(); - if file_scope_id.is_generator_function(index) { - flags |= SignatureFlags::IS_GENERATOR; - } - if has_implicitly_positional_only_first_param( + let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( db, self, function_stmt_node, file_scope_id, index, - ) { - flags |= SignatureFlags::HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER; - } - if !raw { - flags |= SignatureFlags::MARK_TYPEVARS_INFERABLE; - flags |= SignatureFlags::COROUTINE_RETURN_TYPE; - } + ); + Signature::from_function( db, - generic_context, + pep695_ctx, inherited_generic_context, definition, function_stmt_node, - flags, + has_implicitly_positional_first_parameter, ) } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index af24814744beea..73ad82c9b20061 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -216,6 +216,29 @@ impl<'db> GenericContext<'db> { Some(Self::new(db, variables)) } + pub(crate) fn merge_pep695_and_legacy( + db: &'db dyn Db, + pep695_generic_context: Option, + legacy_generic_context: Option, + ) -> Option { + match (legacy_generic_context, pep695_generic_context) { + (Some(legacy_ctx), Some(ctx)) => { + if legacy_ctx + .variables(db) + .iter() + .exactly_one() + .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) + { + Some(legacy_ctx.merge(db, ctx)) + } else { + // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed + Some(ctx) + } + } + (left, right) => left.or(right), + } + } + /// Creates a generic context from the legacy `TypeVar`s that appear in class's base class /// list. pub(crate) fn from_base_classes( diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 6f93ecfa09a1ab..d85ddd3ecca42b 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -298,34 +298,6 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( } } -bitflags::bitflags! { - #[derive(Default, Debug, Copy, Clone)] - pub(crate) struct SignatureFlags: u8 { - const IS_GENERATOR = 1 << 0; - const HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER = 1 << 1; - const MARK_TYPEVARS_INFERABLE = 1 << 2; - const COROUTINE_RETURN_TYPE = 1 << 3; - } -} - -impl SignatureFlags { - pub(crate) fn is_generator(self) -> bool { - self.contains(SignatureFlags::IS_GENERATOR) - } - - pub(crate) fn has_implicitly_positional_first_parameter(self) -> bool { - self.contains(SignatureFlags::HAS_IMPLICITLY_POSITIONAL_FIRST_PARAMETER) - } - - pub(crate) fn mark_typevars_inferable(self) -> bool { - self.contains(SignatureFlags::MARK_TYPEVARS_INFERABLE) - } - - pub(crate) fn coroutine_return_type(self) -> bool { - self.contains(SignatureFlags::COROUTINE_RETURN_TYPE) - } -} - impl<'db> Signature<'db> { pub(crate) fn new(parameters: Parameters<'db>, return_ty: Option>) -> Self { Self { @@ -378,52 +350,29 @@ impl<'db> Signature<'db> { /// Return a typed signature from a function definition. pub(super) fn from_function( db: &'db dyn Db, - generic_context: Option>, + pep695_generic_context: Option>, inherited_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, - flags: SignatureFlags, + has_implicitly_positional_first_parameter: bool, ) -> Self { - let parameters = - Parameters::from_parameters(db, definition, function_node.parameters.as_ref(), flags); - let return_ty = function_node.returns.as_ref().map(|returns| { - let plain_return_ty = definition_expression_type(db, definition, returns.as_ref()); - let plain_return_ty = if flags.mark_typevars_inferable() { - plain_return_ty.apply_type_mapping( - db, - &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition( - definition, - ))), - ) - } else { - plain_return_ty - }; - if function_node.is_async && !flags.is_generator() && flags.coroutine_return_type() { - KnownClass::CoroutineType - .to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty]) - } else { - plain_return_ty - } - }); + let parameters = Parameters::from_parameters( + db, + definition, + function_node.parameters.as_ref(), + has_implicitly_positional_first_parameter, + ); + let return_ty = function_node + .returns + .as_ref() + .map(|returns| definition_expression_type(db, definition, returns.as_ref())); let legacy_generic_context = GenericContext::from_function_params(db, definition, ¶meters, return_ty); - - let full_generic_context = match (legacy_generic_context, generic_context) { - (Some(legacy_ctx), Some(ctx)) => { - if legacy_ctx - .variables(db) - .iter() - .exactly_one() - .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) - { - Some(legacy_ctx.merge(db, ctx)) - } else { - // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed - Some(ctx) - } - } - (left, right) => left.or(right), - }; + let full_generic_context = GenericContext::merge_pep695_and_legacy( + db, + pep695_generic_context, + legacy_generic_context, + ); Self { generic_context: full_generic_context, @@ -434,6 +383,25 @@ impl<'db> Signature<'db> { } } + pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self { + if let Some(definition) = self.definition { + self.apply_type_mapping( + db, + &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))), + ) + } else { + self + } + } + + pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self { + let return_ty = self.return_ty.map(|return_ty| { + KnownClass::CoroutineType + .to_specialized_instance(db, [Type::any(), Type::any(), return_ty]) + }); + Self { return_ty, ..self } + } + /// Returns the signature which accepts any parameters and returns an `Unknown` type. pub(crate) fn unknown() -> Self { Self::new(Parameters::unknown(), Some(Type::unknown())) @@ -1197,7 +1165,7 @@ impl<'db> Parameters<'db> { db: &'db dyn Db, definition: Definition<'db>, parameters: &ast::Parameters, - flags: SignatureFlags, + has_implicitly_positional_first_parameter: bool, ) -> Self { let ast::Parameters { posonlyargs, @@ -1224,7 +1192,6 @@ impl<'db> Parameters<'db> { name: Some(param.parameter.name.id.clone()), default_type: default_type(param), }, - flags, ) }; @@ -1237,7 +1204,7 @@ impl<'db> Parameters<'db> { if positional_only.is_empty() { let pos_or_keyword_iter = pos_or_keyword_iter.by_ref(); - if flags.has_implicitly_positional_first_parameter() { + if has_implicitly_positional_first_parameter { positional_only.extend(pos_or_keyword_iter.next().map(pos_only_param)); } @@ -1257,7 +1224,6 @@ impl<'db> Parameters<'db> { name: arg.parameter.name.id.clone(), default_type: default_type(arg), }, - flags, ) }); @@ -1269,7 +1235,6 @@ impl<'db> Parameters<'db> { ParameterKind::Variadic { name: arg.name.id.clone(), }, - flags, ) }); @@ -1282,7 +1247,6 @@ impl<'db> Parameters<'db> { name: arg.parameter.name.id.clone(), default_type: default_type(arg), }, - flags, ) }); @@ -1294,7 +1258,6 @@ impl<'db> Parameters<'db> { ParameterKind::KeywordVariadic { name: arg.name.id.clone(), }, - flags, ) }); @@ -1589,22 +1552,11 @@ impl<'db> Parameter<'db> { definition: Definition<'db>, parameter: &ast::Parameter, kind: ParameterKind<'db>, - flags: SignatureFlags, ) -> Self { Self { - annotated_type: parameter.annotation().map(|annotation| { - let annotated_ty = definition_expression_type(db, definition, annotation); - if flags.mark_typevars_inferable() { - annotated_ty.apply_type_mapping( - db, - &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition( - definition, - ))), - ) - } else { - annotated_ty - } - }), + annotated_type: parameter + .annotation() + .map(|annotation| definition_expression_type(db, definition, annotation)), kind, form: ParameterForm::Value, } From 40daa3f607733a26546d138e2d23bcee59417f20 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 30 Sep 2025 01:53:01 +0900 Subject: [PATCH 11/26] Update bidirectional.md --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 33560eef4fda5a..5945240ab70765 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -44,7 +44,7 @@ def _(l: list[int] | None = None): reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] l2: list[int] = l or list() - # TODO: should be `list[int]` + # TODO: it would be nice if this were `list[int]` reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] def f[T](x: T, cond: bool) -> T | list[T]: @@ -94,8 +94,8 @@ def wrap_data() -> list[dict]: return list1({}) reveal_type(list1(res)) # revealed: list[dict[Unknown, Unknown] & ~AlwaysFalsy] # `list[dict[Unknown, Unknown] & ~AlwaysFalsy]` and `list[dict[Unknown, Unknown]]` are incompatible, - # but the return type check passes here because the inferred return type is widened - # by bidirectional type inference. + # but the return type check passes here because the type of `list1(res)` is inferred + # by bidirectional type inference using the annotated return type, and the type of `res` is not used. return list1(res) def wrap_data2() -> list[dict] | None: From b0a62f17bcded3f2b3348df7400301813d6cf8e4 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 30 Sep 2025 02:10:27 +0900 Subject: [PATCH 12/26] Update signatures.rs --- crates/ty_python_semantic/src/types/signatures.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index dcab211dea321a..d01208a120a713 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -396,9 +396,10 @@ impl<'db> Signature<'db> { pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self { if let Some(definition) = self.definition { - self.apply_type_mapping( + self.apply_type_mapping_impl( db, &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))), + &ApplyTypeMappingVisitor::default(), ) } else { self From 4211ac92569b3e445e68c7326aedf7f3a8ffa128 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 30 Sep 2025 02:13:50 +0900 Subject: [PATCH 13/26] Update bidirectional.md --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 5945240ab70765..1fe622902d597b 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -67,10 +67,9 @@ d1 = {"x": 1} d2: TD = {"x": 1} d3: dict[str, int] = {"x": 1} -reveal_type(d1) # revealed: dict[@Todo(dict literal key type), @Todo(dict literal value type)] +reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int] reveal_type(d2) # revealed: TD -# TODO: should be `dict[str, int]` -reveal_type(d3) # revealed: dict[@Todo(dict literal key type), @Todo(dict literal value type)] +reveal_type(d3) # revealed: dict[str, int] ``` ## Propagating return type annotation From 2212cc7605145c37c84a250ebdd84e5870452c07 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 30 Sep 2025 03:59:52 +0900 Subject: [PATCH 14/26] Update bidirectional.md --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 1fe622902d597b..2510acc315c938 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -44,7 +44,7 @@ def _(l: list[int] | None = None): reveal_type(l1) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] l2: list[int] = l or list() - # TODO: it would be nice if this were `list[int]` + # it would be better if this were `list[int]`? (https://github.com/astral-sh/ty/issues/136) reveal_type(l2) # revealed: (list[int] & ~AlwaysFalsy) | list[Unknown] def f[T](x: T, cond: bool) -> T | list[T]: From 819a415a2b3e00a590507b1bb1063da06056a16a Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 7 Oct 2025 00:16:59 +0900 Subject: [PATCH 15/26] Update generics.rs --- crates/ty_python_semantic/src/types/generics.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 77250a67c39162..455fc49c528d9f 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -298,7 +298,6 @@ impl<'db> GenericContext<'db> { (Some(legacy_ctx), Some(ctx)) => { if legacy_ctx .variables(db) - .iter() .exactly_one() .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db)) { From 1b6a5059600f9fcbc42bc768b97134cdcdccb644 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Tue, 7 Oct 2025 00:20:08 +0900 Subject: [PATCH 16/26] Apply suggestion from @ibraheemdev Co-Authored-By: Ibraheem Ahmed --- .../src/types/infer/builder.rs | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index b364f60b5ff0dc..3695dc5a73a88e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4754,20 +4754,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - let annotated_return_type = |_| { - nearest_enclosing_function(self.db(), self.index, self.scope()).map(|func| { - // When inferring expressions within a function body, - // the expected type passed should be the "raw" type, - // i.e. type variables in the return type are non-inferable, - // and the return types of async functions are not wrapped in `CoroutineType[...]`. - TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) - }) + let tcx = if ret.value.is_some() { + nearest_enclosing_function(self.db(), self.index, self.scope()) + .map(|func| { + // When inferring expressions within a function body, + // the expected type passed should be the "raw" type, + // i.e. type variables in the return type are non-inferable, + // and the return types of async functions are not wrapped in `CoroutineType[...]`. + TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) + }) + .unwrap_or_default() + } else { + TypeContext::default() }; - let tcx = ret - .value - .as_ref() - .and_then(annotated_return_type) - .unwrap_or_default(); if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) { let range = ret .value From b27c1ed22de63206cd4d961fc7a7b2e3d78148e7 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama <45118249+mtshiba@users.noreply.github.com> Date: Tue, 7 Oct 2025 00:27:24 +0900 Subject: [PATCH 17/26] Apply suggestions from code review Co-authored-by: Ibraheem Ahmed --- crates/ty_python_semantic/resources/mdtest/bidirectional.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 2510acc315c938..d8171b25f1024b 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -36,8 +36,8 @@ l3: list[int] = intermediate reveal_type(intermediate) # revealed: list[Literal[1]] reveal_type(l3) # revealed: list[int] -l4: list[int] | None = list1(1) -reveal_type(l4) # revealed: list[int] +l4: list[int | str] | None = list1(1) +reveal_type(l4) # revealed: list[int | str] def _(l: list[int] | None = None): l1 = l or list() @@ -129,7 +129,7 @@ def f(x: int | str) -> list[int] | list[str]: reveal_type(f(1)) # revealed: list[int] reveal_type(f("a")) # revealed: list[str] -async def g() -> list[int]: +async def g() -> list[int | str]: return list1(1) def h[T](x: T, cond: bool) -> T | list[T]: From 0249ba238284b15a3f7ea966086ee01c42f62965 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 7 Oct 2025 01:30:27 +0900 Subject: [PATCH 18/26] improve bidirectional inference in `infer_collection_literal` --- .../mdtest/assignment/annotations.md | 76 ++++++++++++++++++- crates/ty_python_semantic/src/types.rs | 9 +++ .../ty_python_semantic/src/types/generics.rs | 10 +-- crates/ty_python_semantic/src/types/infer.rs | 6 ++ .../src/types/infer/builder.rs | 7 ++ 5 files changed, 100 insertions(+), 8 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index 8a0eb110e53e22..12812db53a7565 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -150,7 +150,81 @@ r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3} reveal_type(r) # revealed: dict[int | str, int | str] ``` -## Incorrect collection literal assignments are complained aobut +## Optional collection literal annotations are understood + +```toml +[environment] +python-version = "3.12" +``` + +```py +import typing + +a: list[int] | None = [1, 2, 3] +reveal_type(a) # revealed: list[int] + +b: list[int | str] | None = [1, 2, 3] +reveal_type(b) # revealed: list[int | str] + +c: typing.List[int] | None = [1, 2, 3] +reveal_type(c) # revealed: list[int] + +d: list[typing.Any] | None = [] +reveal_type(d) # revealed: list[Any] + +e: set[int] | None = {1, 2, 3} +reveal_type(e) # revealed: set[int] + +f: set[int | str] | None = {1, 2, 3} +reveal_type(f) # revealed: set[int | str] + +g: typing.Set[int] | None = {1, 2, 3} +reveal_type(g) # revealed: set[int] + +h: list[list[int]] | None = [[], [42]] +reveal_type(h) # revealed: list[list[int]] + +i: list[typing.Any] | None = [1, 2, "3", ([4],)] +reveal_type(i) # revealed: list[Any | int | str | tuple[list[Unknown | int]]] + +j: list[tuple[str | int, ...]] | None = [(1, 2), ("foo", "bar"), ()] +reveal_type(j) # revealed: list[tuple[str | int, ...]] + +k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7])] +reveal_type(k) # revealed: list[tuple[list[int], ...]] + +l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"]) +# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]` +reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]] + +type IntList = list[int] + +m: IntList | None = [1, 2, 3] +reveal_type(m) # revealed: list[int] + +# TODO: this should type-check and avoid literal promotion +# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[Literal[1, 2, 3]] | None`" +n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3] +# TODO: this should be `list[Literal[1, 2, 3]]` at this scope +reveal_type(n) # revealed: list[Literal[1, 2, 3]] | None + +# TODO: this should type-check and avoid literal promotion +# error: [invalid-assignment] "Object of type `list[Unknown | str]` is not assignable to `list[LiteralString] | None`" +o: list[typing.LiteralString] | None = ["a", "b", "c"] +# TODO: this should be `list[LiteralString]` at this scope +reveal_type(o) # revealed: list[LiteralString] | None + +p: dict[int, int] | None = {} +reveal_type(p) # revealed: dict[int, int] + +q: dict[int | str, int] | None = {1: 1, 2: 2, 3: 3} +reveal_type(q) # revealed: dict[int | str, int] + +r: dict[int | str, int | str] | None = {1: 1, 2: 2, 3: 3} +reveal_type(r) # revealed: dict[int | str, int | str] +``` + +## Incorrect collection literal assignments are complained about ```py # error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[str]`" diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9d0e04bbba7d0b..9f8000727434ac 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1140,6 +1140,15 @@ impl<'db> Type<'db> { if yes { self.negate(db) } else { *self } } + /// Remove the union elements that are not related to `target`. + pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> { + if let Type::Union(union) = self { + union.filter(db, |elem| !elem.is_disjoint_from(db, target)) + } else { + self + } + } + /// Returns the fallback instance type that a literal is an instance of, or `None` if the type /// is not a literal. pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option> { diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 455fc49c528d9f..1cd51235f3d713 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1159,13 +1159,9 @@ impl<'db> SpecializationBuilder<'db> { return Ok(()); } - if let Type::Union(union) = actual { - // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`. - // So, here we remove the union elements that are not related to `formal`. - actual = union.filter(self.db, |actual_elem| { - !actual_elem.is_disjoint_from(self.db, formal) - }); - } + // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`. + // So, here we remove the union elements that are not related to `formal`. + actual = actual.filter_disjoint_elements(self.db, formal); match (formal, actual) { (Type::Union(_), Type::Union(_)) => { diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 96da8aaaea37af..ee6cd8a748d6f7 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -390,6 +390,12 @@ impl<'db> TypeContext<'db> { self.annotation .and_then(|ty| ty.known_specialization(known_class, db)) } + + pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self { + Self { + annotation: self.annotation.map(f), + } + } } /// Returns the statically-known truthiness of a given expression. diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 3695dc5a73a88e..5c1be5b96907ee 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5560,6 +5560,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { panic!("Typeshed should always have a `{name}` class in `builtins.pyi`") }); + let tcx = tcx.map_annotation(|annotation| { + // Remove any union elements of `annotation` that are not related to `collection_ty`. + // e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list` + let collection_ty = collection_class.to_instance(self.db()); + annotation.filter_disjoint_elements(self.db(), collection_ty) + }); + // Extract the annotated type of `T`, if provided. let annotated_elt_tys = tcx .known_specialization(collection_class, self.db()) From 9507d7274b82567cd000bbfaefb4412251165ffb Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 7 Oct 2025 01:56:10 +0900 Subject: [PATCH 19/26] don't set `generic_context` on the `Signature` returned by `OverloadLiteral::raw_signature` --- crates/ty_python_semantic/src/types/function.rs | 5 +---- crates/ty_python_semantic/src/types/infer/builder.rs | 1 + crates/ty_python_semantic/src/types/signatures.rs | 11 ++--------- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 50e861494d6674..513022135b8ff8 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -370,6 +370,7 @@ impl<'db> OverloadLiteral<'db> { /// Typed internally-visible "raw" signature for this function. /// That is, type variables in parameter types and the return type remain non-inferable, + /// the generic context is not set, /// and the return types of async functions are not wrapped in `CoroutineType[...]`. /// /// ## Warning @@ -440,9 +441,6 @@ impl<'db> OverloadLiteral<'db> { let function_stmt_node = scope.node(db).expect_function().node(&module); let definition = self.definition(db); let index = semantic_index(db, scope.file(db)); - let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| { - GenericContext::from_type_params(db, index, definition, type_params) - }); let file_scope_id = scope.file_scope_id(db); let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( @@ -455,7 +453,6 @@ impl<'db> OverloadLiteral<'db> { Signature::from_function( db, - pep695_ctx, definition, function_stmt_node, has_implicitly_positional_first_parameter, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 5c1be5b96907ee..9e6c2caa4d3a5a 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4760,6 +4760,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // When inferring expressions within a function body, // the expected type passed should be the "raw" type, // i.e. type variables in the return type are non-inferable, + // the generic context is not set, // and the return types of async functions are not wrapped in `CoroutineType[...]`. TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) }) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 5efb22b88db48c..d839720d83fd76 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -386,9 +386,9 @@ impl<'db> Signature<'db> { } /// Return a typed signature from a function definition. + /// Note that `Signature::generic_context` is not set here. pub(super) fn from_function( db: &'db dyn Db, - pep695_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, has_implicitly_positional_first_parameter: bool, @@ -403,16 +403,9 @@ impl<'db> Signature<'db> { .returns .as_ref() .map(|returns| definition_expression_type(db, definition, returns.as_ref())); - let legacy_generic_context = - GenericContext::from_function_params(db, definition, ¶meters, return_ty); - let full_generic_context = GenericContext::merge_pep695_and_legacy( - db, - pep695_generic_context, - legacy_generic_context, - ); Self { - generic_context: full_generic_context, + generic_context: None, definition: Some(definition), parameters, return_ty, From bd5a465c30addee10a2ddd93589885f0055fd50e Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 7 Oct 2025 02:12:19 +0900 Subject: [PATCH 20/26] Revert "don't set `generic_context` on the `Signature` returned by `OverloadLiteral::raw_signature`" This reverts commit 9507d7274b82567cd000bbfaefb4412251165ffb. --- crates/ty_python_semantic/src/types/function.rs | 5 ++++- crates/ty_python_semantic/src/types/infer/builder.rs | 1 - crates/ty_python_semantic/src/types/signatures.rs | 11 +++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 513022135b8ff8..50e861494d6674 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -370,7 +370,6 @@ impl<'db> OverloadLiteral<'db> { /// Typed internally-visible "raw" signature for this function. /// That is, type variables in parameter types and the return type remain non-inferable, - /// the generic context is not set, /// and the return types of async functions are not wrapped in `CoroutineType[...]`. /// /// ## Warning @@ -441,6 +440,9 @@ impl<'db> OverloadLiteral<'db> { let function_stmt_node = scope.node(db).expect_function().node(&module); let definition = self.definition(db); let index = semantic_index(db, scope.file(db)); + let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| { + GenericContext::from_type_params(db, index, definition, type_params) + }); let file_scope_id = scope.file_scope_id(db); let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param( @@ -453,6 +455,7 @@ impl<'db> OverloadLiteral<'db> { Signature::from_function( db, + pep695_ctx, definition, function_stmt_node, has_implicitly_positional_first_parameter, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 9e6c2caa4d3a5a..5c1be5b96907ee 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -4760,7 +4760,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // When inferring expressions within a function body, // the expected type passed should be the "raw" type, // i.e. type variables in the return type are non-inferable, - // the generic context is not set, // and the return types of async functions are not wrapped in `CoroutineType[...]`. TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty) }) diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index d839720d83fd76..5efb22b88db48c 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -386,9 +386,9 @@ impl<'db> Signature<'db> { } /// Return a typed signature from a function definition. - /// Note that `Signature::generic_context` is not set here. pub(super) fn from_function( db: &'db dyn Db, + pep695_generic_context: Option>, definition: Definition<'db>, function_node: &ast::StmtFunctionDef, has_implicitly_positional_first_parameter: bool, @@ -403,9 +403,16 @@ impl<'db> Signature<'db> { .returns .as_ref() .map(|returns| definition_expression_type(db, definition, returns.as_ref())); + let legacy_generic_context = + GenericContext::from_function_params(db, definition, ¶meters, return_ty); + let full_generic_context = GenericContext::merge_pep695_and_legacy( + db, + pep695_generic_context, + legacy_generic_context, + ); Self { - generic_context: None, + generic_context: full_generic_context, definition: Some(definition), parameters, return_ty, From 8834fec17728a1531821309ed2fa89c311fce242 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 7 Oct 2025 02:24:50 +0900 Subject: [PATCH 21/26] Update function.rs --- crates/ty_python_semantic/src/types/function.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 50e861494d6674..2c1d0991b2222d 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -362,6 +362,8 @@ impl<'db> OverloadLiteral<'db> { signature.parameters(), signature.return_ty, ); + // We need to update `signature.generic_context` here, + // because type variables in `GenericContext::variables` are still non-inferable. signature.generic_context = GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx); From d1e9455783d4d8cc3c481b712a67e70b60cc839e Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 8 Oct 2025 02:34:40 +0900 Subject: [PATCH 22/26] improve specialization between unions --- .../mdtest/generics/legacy/functions.md | 28 +++++++++++++++-- .../mdtest/generics/pep695/functions.md | 28 +++++++++++++++-- .../ty_python_semantic/src/types/generics.rs | 31 +++++++++++++++++-- 3 files changed, 81 insertions(+), 6 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 0996328e6f808e..1ed60127ab4a87 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -322,11 +322,35 @@ def union_param(x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] +# TODO: it would be better if this were `Never` reveal_type(union_param(None)) # revealed: Unknown def _(x: int | None): - # TODO: should be `int` - reveal_type(union_param(x)) # revealed: Unknown + reveal_type(union_param(x)) # revealed: int +``` + +```py +def union_param2(x: T | int | str) -> T: + if isinstance(x, (int, str)): + raise ValueError + return x + +# TODO: it would be better if this were `Never` +reveal_type(union_param2("a")) # revealed: Unknown +# TODO: it would be better if this were `Never` +reveal_type(union_param2(1)) # revealed: Unknown +reveal_type(union_param2(None)) # revealed: None + +def _( + a: None | int | str, + b: None | int, + c: None | str, + d: list[int] | None | int | str, +): + reveal_type(union_param2(a)) # revealed: None + reveal_type(union_param2(b)) # revealed: None + reveal_type(union_param2(c)) # revealed: None + reveal_type(union_param2(d)) # revealed: list[int] | None ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 4d46c0c95a3d61..53cb740deab276 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -285,11 +285,35 @@ def union_param[T](x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] +# TODO: it would be better if this were `Never` reveal_type(union_param(None)) # revealed: Unknown def _(x: int | None): - # TODO: should be `int` - reveal_type(union_param(x)) # revealed: Unknown + reveal_type(union_param(x)) # revealed: int +``` + +```py +def union_param2[T](x: T | int | str) -> T: + if isinstance(x, (int, str)): + raise ValueError + return x + +# TODO: it would be better if this were `Never` +reveal_type(union_param2("a")) # revealed: Unknown +# TODO: it would be better if this were `Never` +reveal_type(union_param2(1)) # revealed: Unknown +reveal_type(union_param2(None)) # revealed: None + +def _( + a: None | int | str, + b: None | int, + c: None | str, + d: list[int] | None | int | str, +): + reveal_type(union_param2(a)) # revealed: None + reveal_type(union_param2(b)) # revealed: None + reveal_type(union_param2(c)) # revealed: None + reveal_type(union_param2(d)) # revealed: list[int] | None ``` ```py diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 1cd51235f3d713..e7af87e4cd5bb7 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -17,7 +17,7 @@ use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, - TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, + TypeVarKind, TypeVarVariance, UnionType, any_over_type, binding_type, declaration_type, }; use crate::{Db, FxOrderMap, FxOrderSet}; @@ -1164,10 +1164,37 @@ impl<'db> SpecializationBuilder<'db> { actual = actual.filter_disjoint_elements(self.db, formal); match (formal, actual) { - (Type::Union(_), Type::Union(_)) => { + (Type::Union(formal), Type::Union(actual)) => { // TODO: We need to infer specializations appropriately. // e.g. // `formal: list[T] | T | U, actual: V | int | list[V]` => `T = V, U = int` + + let has_typevar = + |ty| any_over_type(self.db, ty, &|ty| matches!(ty, Type::TypeVar(_)), false); + + // Here we do a simplified operation, we use concrete types (types that don't contain type variables) in `formal` + // to reduce the elements of `actual`. When `formal` or `actual` is no longer a union type, we can fall back to the cases below. + let mut actual = actual; + for concrete in formal + .elements(self.db) + .iter() + .filter(|ty| !has_typevar(**ty)) + { + match actual.filter(self.db, |ty| !ty.is_assignable_to(self.db, *concrete)) { + Type::Union(union) => { + actual = union; + } + other => { + let formal = formal.filter(self.db, |ty| has_typevar(**ty)); + self.infer(formal, other)?; + return Ok(()); + } + } + } + let formal = formal.filter(self.db, |ty| has_typevar(**ty)); + if !matches!(formal, Type::Union(_)) { + self.infer(formal, Type::Union(actual))?; + } } (Type::Union(formal), _) => { // TODO: We haven't implemented a full unification solver yet. If typevars appear From b5fddbcc345394a57545d14011f4d4a4a5524f59 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 8 Oct 2025 02:37:35 +0900 Subject: [PATCH 23/26] Revert "improve specialization between unions" This reverts commit d1e9455783d4d8cc3c481b712a67e70b60cc839e. --- .../mdtest/generics/legacy/functions.md | 28 ++--------------- .../mdtest/generics/pep695/functions.md | 28 ++--------------- .../ty_python_semantic/src/types/generics.rs | 31 ++----------------- 3 files changed, 6 insertions(+), 81 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md index 1ed60127ab4a87..0996328e6f808e 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md @@ -322,35 +322,11 @@ def union_param(x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] -# TODO: it would be better if this were `Never` reveal_type(union_param(None)) # revealed: Unknown def _(x: int | None): - reveal_type(union_param(x)) # revealed: int -``` - -```py -def union_param2(x: T | int | str) -> T: - if isinstance(x, (int, str)): - raise ValueError - return x - -# TODO: it would be better if this were `Never` -reveal_type(union_param2("a")) # revealed: Unknown -# TODO: it would be better if this were `Never` -reveal_type(union_param2(1)) # revealed: Unknown -reveal_type(union_param2(None)) # revealed: None - -def _( - a: None | int | str, - b: None | int, - c: None | str, - d: list[int] | None | int | str, -): - reveal_type(union_param2(a)) # revealed: None - reveal_type(union_param2(b)) # revealed: None - reveal_type(union_param2(c)) # revealed: None - reveal_type(union_param2(d)) # revealed: list[int] | None + # TODO: should be `int` + reveal_type(union_param(x)) # revealed: Unknown ``` ```py diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md index 53cb740deab276..4d46c0c95a3d61 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md @@ -285,35 +285,11 @@ def union_param[T](x: T | None) -> T: reveal_type(union_param("a")) # revealed: Literal["a"] reveal_type(union_param(1)) # revealed: Literal[1] -# TODO: it would be better if this were `Never` reveal_type(union_param(None)) # revealed: Unknown def _(x: int | None): - reveal_type(union_param(x)) # revealed: int -``` - -```py -def union_param2[T](x: T | int | str) -> T: - if isinstance(x, (int, str)): - raise ValueError - return x - -# TODO: it would be better if this were `Never` -reveal_type(union_param2("a")) # revealed: Unknown -# TODO: it would be better if this were `Never` -reveal_type(union_param2(1)) # revealed: Unknown -reveal_type(union_param2(None)) # revealed: None - -def _( - a: None | int | str, - b: None | int, - c: None | str, - d: list[int] | None | int | str, -): - reveal_type(union_param2(a)) # revealed: None - reveal_type(union_param2(b)) # revealed: None - reveal_type(union_param2(c)) # revealed: None - reveal_type(union_param2(d)) # revealed: list[int] | None + # TODO: should be `int` + reveal_type(union_param(x)) # revealed: Unknown ``` ```py diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index e7af87e4cd5bb7..1cd51235f3d713 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -17,7 +17,7 @@ use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, - TypeVarKind, TypeVarVariance, UnionType, any_over_type, binding_type, declaration_type, + TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, }; use crate::{Db, FxOrderMap, FxOrderSet}; @@ -1164,37 +1164,10 @@ impl<'db> SpecializationBuilder<'db> { actual = actual.filter_disjoint_elements(self.db, formal); match (formal, actual) { - (Type::Union(formal), Type::Union(actual)) => { + (Type::Union(_), Type::Union(_)) => { // TODO: We need to infer specializations appropriately. // e.g. // `formal: list[T] | T | U, actual: V | int | list[V]` => `T = V, U = int` - - let has_typevar = - |ty| any_over_type(self.db, ty, &|ty| matches!(ty, Type::TypeVar(_)), false); - - // Here we do a simplified operation, we use concrete types (types that don't contain type variables) in `formal` - // to reduce the elements of `actual`. When `formal` or `actual` is no longer a union type, we can fall back to the cases below. - let mut actual = actual; - for concrete in formal - .elements(self.db) - .iter() - .filter(|ty| !has_typevar(**ty)) - { - match actual.filter(self.db, |ty| !ty.is_assignable_to(self.db, *concrete)) { - Type::Union(union) => { - actual = union; - } - other => { - let formal = formal.filter(self.db, |ty| has_typevar(**ty)); - self.infer(formal, other)?; - return Ok(()); - } - } - } - let formal = formal.filter(self.db, |ty| has_typevar(**ty)); - if !matches!(formal, Type::Union(_)) { - self.infer(formal, Type::Union(actual))?; - } } (Type::Union(formal), _) => { // TODO: We haven't implemented a full unification solver yet. If typevars appear From 97c065df494fa8b11f59da0b32556fb3ce33c919 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 8 Oct 2025 03:57:32 +0900 Subject: [PATCH 24/26] add `TypeDict` test cases --- .../resources/mdtest/bidirectional.md | 7 +++++++ .../resources/mdtest/typed_dict.md | 12 +++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index d8171b25f1024b..3485304b6b4224 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -70,6 +70,13 @@ d3: dict[str, int] = {"x": 1} reveal_type(d1) # revealed: dict[Unknown | str, Unknown | int] reveal_type(d2) # revealed: TD reveal_type(d3) # revealed: dict[str, int] + +def _() -> TD: + return {"x": 1} + +def _() -> TD: + # error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor" + return {} ``` ## Propagating return type annotation diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index dc3734f368b9d2..a6d97632478880 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -284,7 +284,7 @@ qualifiers override the class-level `total` setting, which sets the default (`to all keys are required by default, `total=False` means that all keys are non-required by default): ```py -from typing_extensions import TypedDict, Required, NotRequired +from typing_extensions import TypedDict, Required, NotRequired, Final # total=False by default, but id is explicitly Required class Message(TypedDict, total=False): @@ -298,10 +298,20 @@ class User(TypedDict): email: Required[str] # Explicitly required (redundant here) bio: NotRequired[str] # Optional despite total=True +ID: Final = "id" + # Valid Message constructions msg1 = Message(id=1) # id required, content optional msg2 = Message(id=2, content="Hello") # both provided msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional +msg4: Message = {"id": 4} # id required, content optional +# TODO: no error +# error: [missing-typed-dict-key] +msg5: Message = {ID: 5} # id required, content optional + +def msg() -> Message: + # TODO: no error + return {ID: 1} # error: [missing-typed-dict-key] # Valid User constructions user1 = User(name="Alice", email="alice@example.com") # required fields From 9cea9030ed13b6efa905827a31f4dc37129d6801 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Fri, 10 Oct 2025 14:26:51 +0900 Subject: [PATCH 25/26] update mdtest --- .../resources/mdtest/assignment/annotations.md | 10 ++-------- .../ty_python_semantic/resources/mdtest/typed_dict.md | 5 +---- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index c6303b2c31a9cf..ada7ee9aef49c0 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -198,17 +198,11 @@ type IntList = list[int] m: IntList | None = [1, 2, 3] reveal_type(m) # revealed: list[int] -# TODO: this should type-check and avoid literal promotion -# error: [invalid-assignment] "Object of type `list[Unknown | int]` is not assignable to `list[Literal[1, 2, 3]] | None`" n: list[typing.Literal[1, 2, 3]] | None = [1, 2, 3] -# TODO: this should be `list[Literal[1, 2, 3]]` at this scope -reveal_type(n) # revealed: list[Literal[1, 2, 3]] | None +reveal_type(n) # revealed: list[Literal[1, 2, 3]] -# TODO: this should type-check and avoid literal promotion -# error: [invalid-assignment] "Object of type `list[Unknown | str]` is not assignable to `list[LiteralString] | None`" o: list[typing.LiteralString] | None = ["a", "b", "c"] -# TODO: this should be `list[LiteralString]` at this scope -reveal_type(o) # revealed: list[LiteralString] | None +reveal_type(o) # revealed: list[LiteralString] p: dict[int, int] | None = {} reveal_type(p) # revealed: dict[int, int] diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index ce68781db2d70f..8e53e694653b38 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -345,13 +345,10 @@ msg1 = Message(id=1) # id required, content optional msg2 = Message(id=2, content="Hello") # both provided msg3 = Message(id=3, timestamp="2024-01-01") # id required, timestamp optional msg4: Message = {"id": 4} # id required, content optional -# TODO: no error -# error: [missing-typed-dict-key] msg5: Message = {ID: 5} # id required, content optional def msg() -> Message: - # TODO: no error - return {ID: 1} # error: [missing-typed-dict-key] + return {ID: 1} # Valid User constructions user1 = User(name="Alice", email="alice@example.com") # required fields From e7b5c14c253fbe737d62efda4513de8fb3db7518 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Fri, 10 Oct 2025 20:33:59 -0400 Subject: [PATCH 26/26] update tests --- crates/ty_python_semantic/resources/mdtest/typed_dict.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index 16fe95b294285f..8fc8b2fcb616e5 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -125,9 +125,10 @@ def homogeneous_list[T](*args: T) -> list[T]: reveal_type(homogeneous_list(1, 2, 3)) # revealed: list[Literal[1, 2, 3]] plot2: Plot = {"y": homogeneous_list(1, 2, 3), "x": None} reveal_type(plot2["y"]) # revealed: list[int] -# TODO: no error -# error: [invalid-argument-type] + plot3: Plot = {"y": homogeneous_list(1, 2, 3), "x": homogeneous_list(1, 2, 3)} +reveal_type(plot3["y"]) # revealed: list[int] +reveal_type(plot3["x"]) # revealed: list[int] | None Y = "y" X = "x"