diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index ec1754fba48b8..2dfd2d9f91200 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -942,24 +942,15 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { let target_is_single_paramspec = CallableSignature::signatures_is_single_paramspec(target_overloads); - // If either callable is a ParamSpec, the constraint set should bind the ParamSpec to - // the other callable's signature. We also need to compare the return types — for - // instance, to verify in `Callable[P, int]` that the return type is assignable to - // `int`, or in `Callable[P, T]` to bind `T` to the return type of the other callable. - match (source_is_single_paramspec, target_is_single_paramspec) { - (Some((source_tvar, source_return)), Some((target_tvar, target_return))) => { - let param_spec_matches = ConstraintSet::constrain_typevar( - db, - self.constraints, - source_tvar, - Type::TypeVar(target_tvar), - Type::TypeVar(target_tvar), - ); - let return_types_match = self.check_type_pair(db, source_return, target_return); - return param_spec_matches.and(db, self.constraints, || return_types_match); - } + // TODO: Adding proper support for overloads with ParamSpec will likely require some + // changes here. - (Some((source_tvar, source_return)), None) => { + // Only handle ParamSpec here when we still need the whole overload set. Once we're + // down to a single signature on both sides, let + // `TypeRelationChecker::check_signature_pair_inner` handle the ParamSpec binding + // instead. + match (source_is_single_paramspec, target_is_single_paramspec) { + (Some((source_tvar, source_return)), None) if target_overloads.len() > 1 => { let upper = Type::Callable(CallableType::new( db, CallableSignature::from_overloads(target_overloads.iter().map( @@ -991,7 +982,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return param_spec_matches.and(db, self.constraints, return_types_match); } - (None, Some((target_tvar, target_return))) => { + (None, Some((target_tvar, target_return))) if source_overloads.len() > 1 => { let lower = Type::Callable(CallableType::new( db, CallableSignature::from_overloads(source_overloads.iter().map( @@ -1023,22 +1014,29 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return param_spec_matches.and(db, self.constraints, return_types_match); } - (None, None) => {} + _ => {} } } match (source_overloads, target_overloads) { - ([self_signature], [other_signature]) => { + ([source_signature], [target_signature]) => { // Base case: both callable types contain a single signature. - self.check_signature_pair(db, self_signature, other_signature) + if self.relation.is_constraint_set_assignability() + && (source_signature.parameters.as_paramspec().is_some() + || target_signature.parameters.as_paramspec().is_some()) + { + self.check_signature_pair_inner(db, source_signature, target_signature) + } else { + self.check_signature_pair(db, source_signature, target_signature) + } } - // `self` is possibly overloaded while `other` is definitely not overloaded. - (_, [other_signature]) => { + // source is possibly overloaded while target is definitely not overloaded. + (_, [target_signature]) => { if let Some(aggregate_relation) = self.try_unary_overload_aggregate_relation( db, source_overloads, - other_signature, + target_signature, ) { return aggregate_relation; } @@ -1054,7 +1052,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { }) } - // `self` is definitely not overloaded while `other` is possibly overloaded. + // source is definitely not overloaded while target is possibly overloaded. ([_], _) => { target_overloads .iter() @@ -1067,7 +1065,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { }) } - // `self` is definitely overloaded while `other` is possibly overloaded. + // source is definitely overloaded while target is possibly overloaded. (_, _) => target_overloads .iter() .when_all(db, self.constraints, |target_signature| { @@ -1216,72 +1214,33 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { .is_never_satisfied(db) }; - // Return types are covariant. - if !check_types(source.return_ty, target.return_ty) { - return result; - } - - // A gradual parameter list is a supertype of the "bottom" parameter list (*args: object, - // **kwargs: object). - if target.parameters.is_gradual() - && !source.parameters.is_top() - && source - .parameters - .variadic() - .is_some_and(|(_, param)| param.annotated_type().is_object()) - && source - .parameters - .keyword_variadic() - .is_some_and(|(_, param)| param.annotated_type().is_object()) - { - return self.always(); - } - - // The top signature is supertype of (and assignable from) all other signatures. It is a - // subtype of no signature except itself, and assignable only to the gradual signature. - if target.parameters.is_top() { - return self.always(); - } else if source.parameters.is_top() && !target.parameters.is_gradual() { - return self.never(); - } - - // If either of the parameter lists is gradual (`...`), then it is assignable to and from - // any other parameter list, but not a subtype or supertype of any other parameter list. - if source.parameters.is_gradual() || target.parameters.is_gradual() { - return match self.relation { - TypeRelation::Subtyping | TypeRelation::SubtypingAssuming => self.never(), - TypeRelation::Redundancy { .. } => result.intersect( - db, - self.constraints, - ConstraintSet::from_bool( - self.constraints, - source.parameters.is_gradual() && target.parameters.is_gradual(), - ), - ), - TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => result, - }; - } + // Avoid returning early after checking the return types in case there is a `ParamSpec` type + // variable in either signature to ensure that the `ParamSpec` binding is still applied even + // if the return types are incompatible. + let return_type_checks = check_types(source.return_ty, target.return_ty); if self.relation.is_constraint_set_assignability() { - let source_is_paramspec = source.parameters.as_paramspec(); - let target_is_paramspec = target.parameters.as_paramspec(); + let source_as_paramspec = source.parameters.as_paramspec(); + let target_as_paramspec = target.parameters.as_paramspec(); // If either signature is a ParamSpec, the constraint set should bind the ParamSpec to - // the other signature. - match (source_is_paramspec, target_is_paramspec) { - (Some(source_tvar), Some(target_tvar)) => { + // the other signature before the return-type and gradual/top fast paths can return + // early. We also need to compare the return types here so a return-type mismatch still + // preserves the inferred ParamSpec binding. + match (source_as_paramspec, target_as_paramspec) { + (Some(source_bound_typevar), Some(target_bound_typevar)) => { let param_spec_matches = ConstraintSet::constrain_typevar( db, self.constraints, - source_tvar, - Type::TypeVar(target_tvar), - Type::TypeVar(target_tvar), + source_bound_typevar, + Type::TypeVar(target_bound_typevar), + Type::TypeVar(target_bound_typevar), ); result.intersect(db, self.constraints, param_spec_matches); return result; } - (Some(source_tvar), None) => { + (Some(source_bound_typevar), None) => { let upper = Type::Callable(CallableType::new( db, CallableSignature::single(Signature::new_generic( @@ -1294,7 +1253,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { let param_spec_matches = ConstraintSet::constrain_typevar( db, self.constraints, - source_tvar, + source_bound_typevar, Type::Never, upper, ); @@ -1302,7 +1261,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { return result; } - (None, Some(target_tvar)) => { + (None, Some(target_bound_typevar)) => { let lower = Type::Callable(CallableType::new( db, CallableSignature::single(Signature::new_generic( @@ -1315,7 +1274,7 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { let param_spec_matches = ConstraintSet::constrain_typevar( db, self.constraints, - target_tvar, + target_bound_typevar, lower, Type::object(), ); @@ -1327,6 +1286,51 @@ impl<'c, 'db> TypeRelationChecker<'_, 'c, 'db> { } } + if !return_type_checks { + return result; + } + + // A gradual parameter list is a supertype of the "bottom" parameter list (*args: object, + // **kwargs: object). + if target.parameters.is_gradual() + && !source.parameters.is_top() + && source + .parameters + .variadic() + .is_some_and(|(_, param)| param.annotated_type().is_object()) + && source + .parameters + .keyword_variadic() + .is_some_and(|(_, param)| param.annotated_type().is_object()) + { + return self.always(); + } + + // The top signature is supertype of (and assignable from) all other signatures. It is a + // subtype of no signature except itself, and assignable only to the gradual signature. + if target.parameters.is_top() { + return self.always(); + } else if source.parameters.is_top() && !target.parameters.is_gradual() { + return self.never(); + } + + // If either of the parameter lists is gradual (`...`), then it is assignable to and from + // any other parameter list, but not a subtype or supertype of any other parameter list. + if source.parameters.is_gradual() || target.parameters.is_gradual() { + return match self.relation { + TypeRelation::Subtyping | TypeRelation::SubtypingAssuming => self.never(), + TypeRelation::Redundancy { .. } => result.intersect( + db, + self.constraints, + ConstraintSet::from_bool( + self.constraints, + source.parameters.is_gradual() && target.parameters.is_gradual(), + ), + ), + TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability => result, + }; + } + let mut parameters = ParametersZip { current_source: None, current_target: None,