From 1c4264db71fc907521eef45eb70a4bf512fea3d4 Mon Sep 17 00:00:00 2001 From: Valentin Iovene Date: Thu, 19 Feb 2026 15:21:07 +0100 Subject: [PATCH] [ty] support for `typing.Concatenate` --- .../resources/mdtest/annotations/callable.md | 3 +- .../resources/mdtest/final.md | 4 +- .../resources/mdtest/generics/concatenate.md | 124 ++++++++++++ .../mdtest/generics/pep695/paramspec.md | 3 - .../resources/mdtest/pep613_type_aliases.md | 6 +- crates/ty_python_semantic/src/types.rs | 2 +- .../ty_python_semantic/src/types/display.rs | 13 +- .../src/types/infer/builder.rs | 188 ++++++++++++------ .../types/infer/builder/type_expression.rs | 131 +++++++++++- .../src/types/signatures.rs | 163 ++++++++++++++- 10 files changed, 541 insertions(+), 96 deletions(-) create mode 100644 crates/ty_python_semantic/resources/mdtest/generics/concatenate.md diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md index ad566f868f4abc..c017a0eb14747e 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/callable.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/callable.md @@ -257,8 +257,7 @@ Using `Concatenate` as the first argument to `Callable`: from typing_extensions import Callable, Concatenate def _(c: Callable[Concatenate[int, str, ...], int]): - # TODO: Should reveal the correct signature - reveal_type(c) # revealed: (...) -> int + reveal_type(c) # revealed: (int, str, /, *args: Any, **kwargs: Any) -> int ``` And, as one of the parameter types: diff --git a/crates/ty_python_semantic/resources/mdtest/final.md b/crates/ty_python_semantic/resources/mdtest/final.md index 1b6081908f339b..af42fedb6cb7af 100644 --- a/crates/ty_python_semantic/resources/mdtest/final.md +++ b/crates/ty_python_semantic/resources/mdtest/final.md @@ -1302,7 +1302,9 @@ class Base(ABC): @abstractproperty # error: [deprecated] def value(self) -> int: return 0 - + # TODO: False positive: `Concatenate` in `classmethod.__init__` signature causes spurious + # invalid-argument-type when the type variables are not fully resolved. + # error: [invalid-argument-type] @abstractclassmethod # error: [deprecated] def make(cls) -> "Base": raise NotImplementedError diff --git a/crates/ty_python_semantic/resources/mdtest/generics/concatenate.md b/crates/ty_python_semantic/resources/mdtest/generics/concatenate.md new file mode 100644 index 00000000000000..c0ec59dc343f54 --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/generics/concatenate.md @@ -0,0 +1,124 @@ +# `typing.Concatenate` + +`Concatenate` is used with `Callable` and `ParamSpec` to describe higher-order functions that add, +remove, or transform parameters of other callables. + +## Basic `Callable[Concatenate[..., ...], ...]` types + +### With ellipsis (gradual form) + +```py +from typing_extensions import Callable, Concatenate + +def _(c: Callable[Concatenate[int, ...], str]): + reveal_type(c) # revealed: (int, /, *args: Any, **kwargs: Any) -> str + +def _(c: Callable[Concatenate[int, str, ...], bool]): + reveal_type(c) # revealed: (int, str, /, *args: Any, **kwargs: Any) -> bool +``` + +### With `ParamSpec` + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing_extensions import Callable, Concatenate, ParamSpec + +P = ParamSpec("P") + +def _(c: Callable[Concatenate[int, P], str]): + reveal_type(c) # revealed: (int, /, *args: P@_.args, **kwargs: P@_.kwargs) -> str +``` + +## Decorator that strips a prefix parameter + +A common use case is decorators that strip the first parameter from a callable. + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Callable, reveal_type +from typing_extensions import Concatenate, ParamSpec + +P = ParamSpec("P") + +def with_request[**P, R](f: Callable[Concatenate[int, P], R]) -> Callable[P, R]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return f(0, *args, **kwargs) + return wrapper + +@with_request +def handler(request: int, name: str) -> bool: + return True + +# The decorator strips the first `int` parameter +reveal_type(handler) # revealed: (name: str) -> bool + +# Calling without the stripped parameter should work +handler("test") +``` + +## Multiple prefix parameters + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Callable, reveal_type +from typing_extensions import Concatenate, ParamSpec + +P = ParamSpec("P") + +def add_two_params[**P, R]( + f: Callable[Concatenate[int, str, P], R], +) -> Callable[P, R]: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return f(0, "", *args, **kwargs) + return wrapper + +@add_two_params +def process(a: int, b: str, flag: bool) -> None: + pass + +reveal_type(process) # revealed: (flag: bool) -> None + +process(True) +``` + +## Assignability of `Concatenate` gradual forms + +When both sides of an assignment use `Concatenate[T, ...]`, the prefix parameters must be +compatible. The gradual tail (`...`) still allows assignability for the remaining parameters. + +```py +from typing_extensions import Callable, Concatenate + +def _( + x: Callable[Concatenate[int, ...], None], + y: Callable[Concatenate[str, ...], None], + same: Callable[Concatenate[int, ...], None], + gradual: Callable[..., None], + multi_self: Callable[Concatenate[int, str, ...], None], + multi_other: Callable[Concatenate[str, int, ...], None], +): + # Same prefix types: assignable + x = same + + # Different prefix types: not assignable + x = y # error: [invalid-assignment] + + # Swapped multi-prefix types: not assignable + multi_self = multi_other # error: [invalid-assignment] + + # Pure gradual is assignable to/from Concatenate gradual + x = gradual + gradual = x +``` diff --git a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md index fc818abfdca3d0..62a45051c39d64 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/pep695/paramspec.md @@ -1081,8 +1081,5 @@ class Factory[**P](Protocol): def call_factory[**P](ctr: Factory[P], *args: P.args, **kwargs: P.kwargs) -> int: return ctr("", *args, **kwargs) -# TODO: This should be OK - P should be inferred as [] since my_factory only has `arg: str` -# which matches the prefix. Currently this is a false positive. -# error: [invalid-argument-type] call_factory(my_factory) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md b/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md index 3c1f8831499e0d..9f61bf048be70f 100644 --- a/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md +++ b/crates/ty_python_semantic/resources/mdtest/pep613_type_aliases.md @@ -237,8 +237,7 @@ from typing_extensions import Callable, Concatenate, TypeAliasType MyAlias4: TypeAlias = Callable[Concatenate[dict[str, T], ...], list[U]] def _(c: MyAlias4[int, str]): - # TODO: should be (int, / ...) -> str - reveal_type(c) # revealed: Unknown + reveal_type(c) # revealed: (dict[str, int], /, *args: Any, **kwargs: Any) -> list[str] T = TypeVar("T") @@ -270,8 +269,7 @@ def _(x: ListOrDict[int]): MyAlias7: TypeAlias = Callable[Concatenate[T, ...], None] def _(c: MyAlias7[int]): - # TODO: should be (int, / ...) -> None - reveal_type(c) # revealed: Unknown + reveal_type(c) # revealed: (int, /, *args: Any, **kwargs: Any) -> None ``` ## Imported diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 9fea5bf06d481a..709068715f933f 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -73,7 +73,7 @@ pub(crate) use crate::types::narrow::{ infer_narrowing_constraint, }; use crate::types::newtype::NewType; -pub(crate) use crate::types::signatures::{Parameter, Parameters}; +pub(crate) use crate::types::signatures::{ConcatenateTail, Parameter, Parameters}; use crate::types::signatures::{ParameterForm, walk_signature}; use crate::types::tuple::{Tuple, TupleSpec, TupleSpecBuilder}; use crate::types::typed_dict::TypedDictField; diff --git a/crates/ty_python_semantic/src/types/display.rs b/crates/ty_python_semantic/src/types/display.rs index 3a23d6eab45f63..1a67f007839a13 100644 --- a/crates/ty_python_semantic/src/types/display.rs +++ b/crates/ty_python_semantic/src/types/display.rs @@ -2084,12 +2084,13 @@ impl<'db> FmtDetailed<'db> for DisplayParameters<'_, 'db> { fn fmt_detailed(&self, f: &mut TypeWriter<'_, '_, 'db>) -> fmt::Result { // For `ParamSpec` kind, the parameters still contain `*args` and `**kwargs`, but we // display them as `**P` instead, so avoid multiline in that case. - // TODO: This might change once we support `Concatenate` + // For `Gradual` kind without prefix params (len <= 2), display as `...`. let multiline = self.settings.multiline && self.parameters.len() > 1 + && !matches!(self.parameters.kind(), ParametersKind::ParamSpec(_)) && !matches!( self.parameters.kind(), - ParametersKind::Gradual | ParametersKind::ParamSpec(_) + ParametersKind::Gradual | ParametersKind::Top ); // Opening parenthesis f.write_char('(')?; @@ -2097,7 +2098,7 @@ impl<'db> FmtDetailed<'db> for DisplayParameters<'_, 'db> { f.write_str("\n ")?; } match self.parameters.kind() { - ParametersKind::Standard => { + ParametersKind::Standard | ParametersKind::Concatenate(_) => { let mut star_added = false; let mut needs_slash = false; let mut first = true; @@ -2149,9 +2150,9 @@ impl<'db> FmtDetailed<'db> for DisplayParameters<'_, 'db> { } } ParametersKind::Gradual | ParametersKind::Top => { - // We represent gradual form as `...` in the signature, internally the parameters still - // contain `(*args, **kwargs)` parameters. (Top parameters are displayed the same - // as gradual parameters, we just wrap the entire signature in `Top[]`.) + // We represent gradual form as `...` in the signature, internally the parameters + // still contain `(*args, **kwargs)` parameters. (Top parameters are displayed the + // same as gradual parameters, we just wrap the entire signature in `Top[]`.) f.write_str("...")?; } ParametersKind::ParamSpec(typevar) => { diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 36ff0f03012e30..8b293e3472abad 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -57,7 +57,7 @@ use crate::semantic_index::scope::{ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol}; use crate::semantic_index::{ ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, attribute_assignments, - get_loop_header, place_table, + get_loop_header, place_table, semantic_index, }; use crate::types::builder::RecursivelyDefined; use crate::types::call::bind::{CallableDescription, MatchingOverloadIndex}; @@ -128,16 +128,16 @@ use crate::types::typed_dict::{ }; use crate::types::{ BoundTypeVarIdentity, BoundTypeVarInstance, CallDunderError, CallableBinding, CallableType, - CallableTypeKind, ClassType, DataclassParams, DynamicType, InternedConstraintSet, InternedType, - IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, KnownUnion, - LintDiagnosticGuard, LiteralValueType, LiteralValueTypeKind, MemberLookupPolicy, - MetaclassCandidate, PEP695TypeAliasType, ParamSpecAttrKind, Parameter, ParameterForm, - Parameters, Signature, SpecialFormType, StaticClassLiteral, SubclassOfType, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, TypeVarBoundOrConstraints, - TypeVarBoundOrConstraintsEvaluation, TypeVarConstraints, TypeVarDefaultEvaluation, - TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, - UnionType, UnionTypeInstance, any_over_type, binding_type, definition_expression_type, - infer_complete_scope_types, infer_scope_types, todo_type, + CallableTypeKind, ClassType, ConcatenateTail, DataclassParams, DynamicType, + InternedConstraintSet, InternedType, IntersectionBuilder, IntersectionType, KnownClass, + KnownInstanceType, KnownUnion, LintDiagnosticGuard, LiteralValueType, LiteralValueTypeKind, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, ParamSpecAttrKind, Parameter, + ParameterForm, Parameters, Signature, SpecialFormType, StaticClassLiteral, SubclassOfType, + Truthiness, Type, TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers, + TypeVarBoundOrConstraints, TypeVarBoundOrConstraintsEvaluation, TypeVarConstraints, + TypeVarDefaultEvaluation, TypeVarIdentity, TypeVarInstance, TypeVarKind, TypeVarVariance, + TypedDictType, UnionBuilder, UnionType, UnionTypeInstance, any_over_type, binding_type, + definition_expression_type, infer_complete_scope_types, infer_scope_types, todo_type, }; use crate::types::{CallableTypes, overrides}; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; @@ -4874,8 +4874,120 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return Ok(Type::paramspec_value_callable(db, parameters)); } - ast::Expr::Subscript(_) => { - // TODO: Support `Concatenate[...]` + ast::Expr::Subscript(subscript) => { + let value_ty = self.infer_expression(&subscript.value, TypeContext::default()); + + if matches!(value_ty, Type::SpecialForm(SpecialFormType::Concatenate)) { + let arguments_slice = &*subscript.slice; + let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { + &*tuple.elts + } else { + std::slice::from_ref(arguments_slice) + }; + + let num_arguments = arguments.len(); + if num_arguments < 2 { + for argument in arguments { + self.infer_type_expression(argument); + } + if arguments_slice.is_tuple_expr() { + self.store_expression_type(arguments_slice, Type::unknown()); + } + return Ok(Type::paramspec_value_callable( + db, + Parameters::gradual_form(), + )); + } + + let (prefix_args, last_arg) = arguments.split_at(arguments.len() - 1); + let last_arg = &last_arg[0]; + + let mut params: Vec> = Vec::with_capacity(num_arguments); + for arg in prefix_args { + let ty = self.infer_type_expression(arg); + params.push(Parameter::positional_only(None).with_annotated_type(ty)); + } + + let result = match last_arg { + ast::Expr::EllipsisLiteral(_) => { + self.infer_type_expression(last_arg); + params.push( + Parameter::variadic(Name::new_static("args")) + .with_annotated_type(Type::Dynamic(DynamicType::Any)), + ); + params.push( + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(Type::Dynamic(DynamicType::Any)), + ); + Some( + Parameters::new(db, params) + .into_concatenate(ConcatenateTail::Gradual), + ) + } + ast::Expr::Name(name) if !name.is_invalid() => { + let name_ty = self.infer_name_load(name); + if let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) = + name_ty + && typevar.is_paramspec(db) + { + let index = semantic_index(db, self.scope().file(db)); + if let Some(bound_typevar) = bind_typevar( + db, + index, + self.scope().file_scope_id(db), + self.typevar_binding_context, + typevar, + ) { + params.push( + Parameter::variadic(Name::new_static("args")) + .with_annotated_type(Type::TypeVar( + bound_typevar.with_paramspec_attr( + db, + ParamSpecAttrKind::Args, + ), + )), + ); + params.push( + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(Type::TypeVar( + bound_typevar.with_paramspec_attr( + db, + ParamSpecAttrKind::Kwargs, + ), + )), + ); + Some(Parameters::new(db, params).into_concatenate( + ConcatenateTail::ParamSpec(bound_typevar), + )) + } else { + None + } + } else { + None + } + } + _ => { + self.infer_type_expression(last_arg); + None + } + }; + + if arguments_slice.is_tuple_expr() { + let inferred_type = if result.is_some() { + todo_type!("`Concatenate[]` special form") + } else { + Type::unknown() + }; + self.store_expression_type(arguments_slice, inferred_type); + } + + return Ok(Type::paramspec_value_callable( + db, + result.unwrap_or_else(Parameters::todo), + )); + } + + // Non-Concatenate subscript: fall back to todo return Ok(Type::paramspec_value_callable(db, Parameters::todo())); } @@ -16033,56 +16145,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { )); } Type::SpecialForm(SpecialFormType::Callable) => { - let arguments = if let ast::Expr::Tuple(tuple) = &*subscript.slice { - &*tuple.elts - } else { - std::slice::from_ref(&*subscript.slice) - }; - - // TODO: Remove this once we support Concatenate properly. This is necessary - // to avoid a lot of false positives downstream, because we can't represent the typevar- - // specialized `Callable` types yet. - let num_arguments = arguments.len(); - if num_arguments == 2 { - let first_arg = &arguments[0]; - let second_arg = &arguments[1]; - - if first_arg.is_subscript_expr() { - let first_arg_ty = self.infer_expression(first_arg, TypeContext::default()); - if let Type::Dynamic(DynamicType::UnknownGeneric(generic_context)) = - first_arg_ty - { - let mut variables = generic_context - .variables(self.db()) - .collect::>(); - - let return_ty = - self.infer_expression(second_arg, TypeContext::default()); - return_ty.bind_and_find_all_legacy_typevars( - self.db(), - self.typevar_binding_context, - &mut variables, - ); - - let generic_context = - GenericContext::from_typevar_instances(self.db(), variables); - return Type::Dynamic(DynamicType::UnknownGeneric(generic_context)); - } - - if let Some(builder) = - self.context.report_lint(&INVALID_TYPE_FORM, subscript) - { - builder.into_diagnostic(format_args!( - "The first argument to `Callable` must be either a list of types, \ - ParamSpec, Concatenate, or `...`", - )); - } - return Type::KnownInstance(KnownInstanceType::Callable( - CallableType::unknown(self.db()), - )); - } - } - let callable = self .infer_callable_type(subscript) .as_callable() diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index f41fce6f6fc9e2..aa3e8e779ba29e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -10,14 +10,16 @@ use crate::types::diagnostic::{ }; use crate::types::generics::bind_typevar; use crate::types::infer::builder::InnerExpressionInferenceState; -use crate::types::signatures::Signature; +use crate::types::signatures::{ConcatenateTail, Signature}; use crate::types::string_annotation::parse_string_annotation; use crate::types::tuple::{TupleSpecBuilder, TupleType}; +use ruff_python_ast::name::Name; + use crate::types::{ BindingContext, CallableType, DynamicType, GenericContext, IntersectionBuilder, KnownClass, - KnownInstanceType, LintDiagnosticGuard, LiteralValueTypeKind, Parameter, Parameters, - SpecialFormType, SubclassOfType, Type, TypeAliasType, TypeContext, TypeGuardType, TypeIsType, - TypeMapping, TypeVarKind, UnionBuilder, UnionType, any_over_type, todo_type, + KnownInstanceType, LintDiagnosticGuard, LiteralValueTypeKind, ParamSpecAttrKind, Parameter, + Parameters, SpecialFormType, SubclassOfType, Type, TypeAliasType, TypeContext, TypeGuardType, + TypeIsType, TypeMapping, TypeVarKind, UnionBuilder, UnionType, any_over_type, todo_type, }; /// Type expressions @@ -1921,8 +1923,127 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::Subscript(subscript) => { let value_ty = self.infer_expression(&subscript.value, TypeContext::default()); + + if matches!(value_ty, Type::SpecialForm(SpecialFormType::Concatenate)) { + let arguments_slice = &*subscript.slice; + let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice { + &*tuple.elts + } else { + std::slice::from_ref(arguments_slice) + }; + + let num_arguments = arguments.len(); + if num_arguments < 2 { + // Validation: Concatenate needs at least 2 args. + // Still infer all argument types for side effects. + for argument in arguments { + self.infer_type_expression(argument); + } + if let Some(builder) = + self.context.report_lint(&INVALID_TYPE_FORM, subscript) + { + builder.into_diagnostic(format_args!( + "Special form `typing.Concatenate` expected at least 2 parameters but got {num_arguments}", + )); + } + if arguments_slice.is_tuple_expr() { + self.store_expression_type(arguments_slice, Type::unknown()); + } + return Some(Parameters::gradual_form()); + } + + let (prefix_args, last_arg) = arguments.split_at(arguments.len() - 1); + let last_arg = &last_arg[0]; + + // Infer prefix argument types as positional-only parameters. + let mut params: Vec> = Vec::with_capacity(num_arguments); + for arg in prefix_args { + let ty = self.infer_type_expression(arg); + params.push(Parameter::positional_only(None).with_annotated_type(ty)); + } + + // The last argument must be a ParamSpec or `...`. + let result = match last_arg { + ast::Expr::EllipsisLiteral(_) => { + self.infer_type_expression(last_arg); + // Gradual form: prefix params + *args: Any, **kwargs: Any + params.push( + Parameter::variadic(Name::new_static("args")) + .with_annotated_type(Type::Dynamic(DynamicType::Any)), + ); + params.push( + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(Type::Dynamic(DynamicType::Any)), + ); + Some( + Parameters::new(self.db(), params) + .into_concatenate(ConcatenateTail::Gradual), + ) + } + ast::Expr::Name(name) if !name.is_invalid() => { + let name_ty = self.infer_name_load(name); + if let Type::KnownInstance(KnownInstanceType::TypeVar(typevar)) = + name_ty + && typevar.is_paramspec(self.db()) + { + let index = semantic_index(self.db(), self.scope().file(self.db())); + if let Some(bound_typevar) = bind_typevar( + self.db(), + index, + self.scope().file_scope_id(self.db()), + self.typevar_binding_context, + typevar, + ) { + // Prefix params + *P.args, **P.kwargs + params.push( + Parameter::variadic(Name::new_static("args")) + .with_annotated_type(Type::TypeVar( + bound_typevar.with_paramspec_attr( + self.db(), + ParamSpecAttrKind::Args, + ), + )), + ); + params.push( + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(Type::TypeVar( + bound_typevar.with_paramspec_attr( + self.db(), + ParamSpecAttrKind::Kwargs, + ), + )), + ); + Some(Parameters::new(self.db(), params).into_concatenate( + ConcatenateTail::ParamSpec(bound_typevar), + )) + } else { + None + } + } else { + // Not a ParamSpec — infer the type expression for side effects + // (it was already inferred via infer_name_load above) + None + } + } + _ => { + self.infer_type_expression(last_arg); + None + } + }; + + let inferred_type = if result.is_some() { + todo_type!("`Concatenate[]` special form") + } else { + Type::unknown() + }; + if arguments_slice.is_tuple_expr() { + self.store_expression_type(arguments_slice, inferred_type); + } + return Some(result.unwrap_or_else(Parameters::todo)); + } + self.infer_subscript_type_expression(subscript, value_ty); - // TODO: Support `Concatenate[...]` + // Non-Concatenate subscript (e.g. Unpack): fall back to todo return Some(Parameters::todo()); } ast::Expr::Name(name) => { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index 53f7aef427f760..866531625a208b 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -1349,12 +1349,36 @@ impl<'db> Signature<'db> { // If either of the parameter lists is gradual (`...`), then it is assignable to and from // any other parameter list, but not a subtype or supertype of any other parameter list. if self.parameters.is_gradual() || other.parameters.is_gradual() { - result.intersect( - db, - ConstraintSet::from( - relation.is_assignability() || relation.is_constraint_set_assignability(), - ), + if !(relation.is_assignability() || relation.is_constraint_set_assignability()) { + return ConstraintSet::from(false); + } + + // For Concatenate[T, ...] forms, check that the prefix params are compatible. + // The prefix params precede the *args/**kwargs gradual tail, so prefix_len = total_params - 2. + let self_is_concat_gradual = matches!( + self.parameters.kind(), + ParametersKind::Concatenate(ConcatenateTail::Gradual) + ); + let other_is_concat_gradual = matches!( + other.parameters.kind(), + ParametersKind::Concatenate(ConcatenateTail::Gradual) ); + + if self_is_concat_gradual && other_is_concat_gradual { + let self_prefix_len = self.parameters.len() - 2; + let other_prefix_len = other.parameters.len() - 2; + let common = self_prefix_len.min(other_prefix_len); + for i in 0..common { + // Parameters are contravariant + if !check_types( + other.parameters.as_slice()[i].annotated_type(), + self.parameters.as_slice()[i].annotated_type(), + ) { + return result; + } + } + } + return result; } @@ -1416,7 +1440,100 @@ impl<'db> Signature<'db> { return result; } - (None, None) => {} + (None, None) => { + // Check for Concatenate-style signatures: prefix params followed by + // *P.args, **P.kwargs. These are not detected by as_paramspec() because + // the prefix params change the ParametersKind from ParamSpec to Standard. + + // Case: `other` has Concatenate pattern, `self` is concrete + if let Some((other_prefix, other_paramspec)) = + other.parameters.find_paramspec_from_args_kwargs(db) + && !other_prefix.is_empty() + { + let self_params = self.parameters.as_slice(); + if self_params.len() >= other_prefix.len() { + // Check prefix param types (contravariant) + for (other_param, self_param) in + other_prefix.iter().zip(self_params.iter()) + { + result.intersect( + db, + other_param.annotated_type().has_relation_to_impl( + db, + self_param.annotated_type(), + inferable, + relation, + relation_visitor, + disjointness_visitor, + ), + ); + } + + // Bind the ParamSpec to the remaining self params + let remaining_params = &self_params[other_prefix.len()..]; + let lower = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new( + Parameters::new(db, remaining_params.iter().cloned()), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + other_paramspec, + lower, + Type::object(), + ); + result.intersect(db, param_spec_matches); + return result; + } + } + + // Case: `self` has Concatenate pattern, `other` is concrete + if let Some((self_prefix, self_paramspec)) = + self.parameters.find_paramspec_from_args_kwargs(db) + && !self_prefix.is_empty() + { + let other_params = other.parameters.as_slice(); + if other_params.len() >= self_prefix.len() { + // Check prefix param types (contravariant) + for (self_param, other_param) in + self_prefix.iter().zip(other_params.iter()) + { + result.intersect( + db, + self_param.annotated_type().has_relation_to_impl( + db, + other_param.annotated_type(), + inferable, + relation, + relation_visitor, + disjointness_visitor, + ), + ); + } + + let remaining_params = &other_params[self_prefix.len()..]; + let upper = Type::Callable(CallableType::new( + db, + CallableSignature::single(Signature::new( + Parameters::new(db, remaining_params.iter().cloned()), + Type::unknown(), + )), + CallableTypeKind::ParamSpecValue, + )); + let param_spec_matches = ConstraintSet::constrain_typevar( + db, + self_paramspec, + Type::Never, + upper, + ); + result.intersect(db, param_spec_matches); + return result; + } + } + } } } @@ -1741,9 +1858,14 @@ impl<'db> VarianceInferable<'db> for &Signature<'db> { } } -// TODO: the spec also allows signatures like `Concatenate[int, ...]` or `Concatenate[int, P]`, -// which have some number of required positional-only parameters followed by a gradual form or a -// `ParamSpec`. Our representation will need some adjustments to represent that. +/// The tail of a `Concatenate[T, ..., tail]` form. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] +pub(crate) enum ConcatenateTail<'db> { + /// `Concatenate[T, ...]` — prefix params followed by a gradual `*args: Any, **kwargs: Any`. + Gradual, + /// `Concatenate[T, P]` — prefix params followed by a `ParamSpec`. + ParamSpec(BoundTypeVarInstance<'db>), +} /// The kind of parameter list represented. #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] @@ -1777,6 +1899,10 @@ pub(crate) enum ParametersKind<'db> { // TODO: Maybe we should use `find_paramspec_from_args_kwargs` instead of storing the typevar // here? ParamSpec(BoundTypeVarInstance<'db>), + + /// Represents a `Concatenate[T, ..., tail]` form: some number of required positional-only + /// prefix parameters followed by either a gradual form or a `ParamSpec`. + Concatenate(ConcatenateTail<'db>), } #[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)] @@ -1846,7 +1972,19 @@ impl<'db> Parameters<'db> { } pub(crate) const fn is_gradual(&self) -> bool { - matches!(self.kind, ParametersKind::Gradual) + matches!( + self.kind, + ParametersKind::Gradual | ParametersKind::Concatenate(ConcatenateTail::Gradual) + ) + } + + /// Set the kind to `Concatenate(tail)`, used for `Concatenate[T, ..., tail]` forms where + /// prefix parameters precede either a gradual or `ParamSpec` suffix. + pub(crate) fn into_concatenate(self, tail: ConcatenateTail<'db>) -> Self { + Self { + kind: ParametersKind::Concatenate(tail), + ..self + } } pub(crate) const fn is_top(&self) -> bool { @@ -2116,7 +2254,10 @@ impl<'db> Parameters<'db> { visitor: &ApplyTypeMappingVisitor<'db>, ) -> Self { if let TypeMapping::Materialize(materialization_kind) = type_mapping - && self.kind == ParametersKind::Gradual + && matches!( + self.kind, + ParametersKind::Gradual | ParametersKind::Concatenate(ConcatenateTail::Gradual) + ) { match materialization_kind { MaterializationKind::Bottom => {