diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 1c14cf6d18cad..e1f46d9f9cbf9 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -59,6 +59,7 @@ use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature}; use crate::types::tuple::{TupleSpec, TupleType}; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic; use crate::{Db, FxOrderSet, Module, Program}; pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass}; @@ -390,17 +391,18 @@ pub struct PropertyInstanceType<'db> { setter: Option>, } -fn walk_property_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_property_instance_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, property: PropertyInstanceType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { if let Some(getter) = property.getter(db) { - visitor.visit_type(db, getter); + visitor.visit_type(db, getter)?; } if let Some(setter) = property.setter(db) { - visitor.visit_type(db, setter); + visitor.visit_type(db, setter)?; } + Ok(()) } // The Salsa heap is tracked separately. @@ -425,19 +427,6 @@ impl<'db> PropertyInstanceType<'db> { ) } - fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - if let Some(ty) = self.getter(db) { - ty.find_legacy_typevars(db, typevars); - } - if let Some(ty) = self.setter(db) { - ty.find_legacy_typevars(db, typevars); - } - } - fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { Self::new( db, @@ -5561,110 +5550,6 @@ impl<'db> Type<'db> { } } - /// Locates any legacy `TypeVar`s in this type, and adds them to a set. This is used to build - /// up a generic context from any legacy `TypeVar`s that appear in a function parameter list or - /// `Generic` specialization. - pub(crate) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self { - Type::TypeVar(typevar) => { - if typevar.is_legacy(db) { - typevars.insert(typevar); - } - } - - Type::FunctionLiteral(function) => function.find_legacy_typevars(db, typevars), - - Type::BoundMethod(method) => { - method.self_instance(db).find_legacy_typevars(db, typevars); - method.function(db).find_legacy_typevars(db, typevars); - } - - Type::MethodWrapper( - MethodWrapperKind::FunctionTypeDunderGet(function) - | MethodWrapperKind::FunctionTypeDunderCall(function), - ) => { - function.find_legacy_typevars(db, typevars); - } - - Type::MethodWrapper( - MethodWrapperKind::PropertyDunderGet(property) - | MethodWrapperKind::PropertyDunderSet(property), - ) => { - property.find_legacy_typevars(db, typevars); - } - - Type::Callable(callable) => { - callable.find_legacy_typevars(db, typevars); - } - - Type::PropertyInstance(property) => { - property.find_legacy_typevars(db, typevars); - } - - Type::Union(union) => { - for element in union.iter(db) { - element.find_legacy_typevars(db, typevars); - } - } - Type::Intersection(intersection) => { - for positive in intersection.positive(db) { - positive.find_legacy_typevars(db, typevars); - } - for negative in intersection.negative(db) { - negative.find_legacy_typevars(db, typevars); - } - } - - Type::Tuple(tuple) => { - tuple.find_legacy_typevars(db, typevars); - } - - Type::GenericAlias(alias) => { - alias.find_legacy_typevars(db, typevars); - } - - Type::NominalInstance(instance) => { - instance.find_legacy_typevars(db, typevars); - } - - Type::ProtocolInstance(instance) => { - instance.find_legacy_typevars(db, typevars); - } - - Type::SubclassOf(subclass_of) => { - subclass_of.find_legacy_typevars(db, typevars); - } - - Type::TypeIs(type_is) => { - type_is.return_type(db).find_legacy_typevars(db, typevars); - } - - Type::Dynamic(_) - | Type::Never - | Type::AlwaysTruthy - | Type::AlwaysFalsy - | Type::WrapperDescriptor(_) - | Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) - | Type::DataclassDecorator(_) - | Type::DataclassTransformer(_) - | Type::ModuleLiteral(_) - | Type::ClassLiteral(_) - | Type::IntLiteral(_) - | Type::BooleanLiteral(_) - | Type::LiteralString - | Type::StringLiteral(_) - | Type::BytesLiteral(_) - | Type::EnumLiteral(_) - | Type::BoundSuper(_) - | Type::SpecialForm(_) - | Type::KnownInstance(_) => {} - } - } - /// Return the string representation of this type when converted to string as it would be /// provided by the `__str__` method. /// @@ -5889,20 +5774,21 @@ pub enum TypeMapping<'a, 'db> { PromoteLiterals, } -fn walk_type_mapping<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_type_mapping<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, mapping: &TypeMapping<'_, 'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match mapping { TypeMapping::Specialization(specialization) => { - walk_specialization(db, *specialization, visitor); + walk_specialization(db, *specialization, visitor)?; } TypeMapping::PartialSpecialization(specialization) => { - walk_partial_specialization(db, specialization, visitor); + walk_partial_specialization(db, specialization, visitor)?; } TypeMapping::PromoteLiterals => {} } + Ok(()) } impl<'db> TypeMapping<'_, 'db> { @@ -5975,29 +5861,30 @@ pub enum KnownInstanceType<'db> { Field(FieldInstance<'db>), } -fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_known_instance_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, known_instance: KnownInstanceType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match known_instance { KnownInstanceType::SubscriptedProtocol(context) | KnownInstanceType::SubscriptedGeneric(context) => { - walk_generic_context(db, context, visitor); + walk_generic_context(db, context, visitor)?; } KnownInstanceType::TypeVar(typevar) => { - visitor.visit_type_var_type(db, typevar); + visitor.visit_type_var_type(db, typevar)?; } KnownInstanceType::TypeAliasType(type_alias) => { - visitor.visit_type_alias_type(db, type_alias); + visitor.visit_type_alias_type(db, type_alias)?; } KnownInstanceType::Deprecated(_) => { // Nothing to visit } KnownInstanceType::Field(field) => { - visitor.visit_type(db, field.default_type(db)); + visitor.visit_type(db, field.default_type(db))?; } } + Ok(()) } impl<'db> KnownInstanceType<'db> { @@ -6475,17 +6362,18 @@ pub struct TypeVarInstance<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for TypeVarInstance<'_> {} -fn walk_type_var_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_type_var_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_var: TypeVarInstance<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { if let Some(bounds) = type_var.bound_or_constraints(db) { - walk_type_var_bounds(db, bounds, visitor); + walk_type_var_bounds(db, bounds, visitor)?; } if let Some(default_type) = type_var.default_ty(db) { - visitor.visit_type(db, default_type); + visitor.visit_type(db, default_type)?; } + Ok(()) } impl<'db> TypeVarInstance<'db> { @@ -6568,17 +6456,18 @@ pub enum TypeVarBoundOrConstraints<'db> { Constraints(UnionType<'db>), } -fn walk_type_var_bounds<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_type_var_bounds<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bounds: TypeVarBoundOrConstraints<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match bounds { - TypeVarBoundOrConstraints::UpperBound(bound) => visitor.visit_type(db, bound), + TypeVarBoundOrConstraints::UpperBound(bound) => visitor.visit_type(db, bound)?, TypeVarBoundOrConstraints::Constraints(constraints) => { - visitor.visit_union_type(db, constraints); + visitor.visit_union_type(db, constraints)?; } } + Ok(()) } impl<'db> TypeVarBoundOrConstraints<'db> { @@ -7485,13 +7374,14 @@ pub struct BoundMethodType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for BoundMethodType<'_> {} -fn walk_bound_method_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_bound_method_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, method: BoundMethodType<'db>, visitor: &mut V, -) { - visitor.visit_function_type(db, method.function(db)); - visitor.visit_type(db, method.self_instance(db)); +) -> TypeVisitorResult { + visitor.visit_function_type(db, method.function(db))?; + visitor.visit_type(db, method.self_instance(db))?; + Ok(()) } impl<'db> BoundMethodType<'db> { @@ -7559,14 +7449,15 @@ pub struct CallableType<'db> { is_function_like: bool, } -pub(super) fn walk_callable_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_callable_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, ty: CallableType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for signature in &ty.signatures(db).overloads { - walk_signature(db, signature, visitor); + walk_signature(db, signature, visitor)?; } + Ok(()) } // The Salsa heap is tracked separately. @@ -7642,14 +7533,6 @@ impl<'db> CallableType<'db> { ) } - fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - self.signatures(db).find_legacy_typevars(db, typevars); - } - /// Check whether this callable type has the given relation to another callable type. /// /// See [`Type::is_subtype_of`] and [`Type::is_assignable_to`] for more details. @@ -7697,28 +7580,29 @@ pub enum MethodWrapperKind<'db> { StrStartswith(StringLiteralType<'db>), } -pub(super) fn walk_method_wrapper_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_method_wrapper_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, method_wrapper: MethodWrapperKind<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match method_wrapper { MethodWrapperKind::FunctionTypeDunderGet(function) => { - visitor.visit_function_type(db, function); + visitor.visit_function_type(db, function)?; } MethodWrapperKind::FunctionTypeDunderCall(function) => { - visitor.visit_function_type(db, function); + visitor.visit_function_type(db, function)?; } MethodWrapperKind::PropertyDunderGet(property) => { - visitor.visit_property_instance_type(db, property); + visitor.visit_property_instance_type(db, property)?; } MethodWrapperKind::PropertyDunderSet(property) => { - visitor.visit_property_instance_type(db, property); + visitor.visit_property_instance_type(db, property)?; } MethodWrapperKind::StrStartswith(string_literal) => { - visitor.visit_type(db, Type::StringLiteral(string_literal)); + visitor.visit_type(db, Type::StringLiteral(string_literal))?; } } + Ok(()) } impl<'db> MethodWrapperKind<'db> { @@ -7918,12 +7802,13 @@ pub struct PEP695TypeAliasType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for PEP695TypeAliasType<'_> {} -fn walk_pep_695_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_pep_695_type_alias<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: PEP695TypeAliasType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, type_alias.value_type(db)); +) -> TypeVisitorResult { + visitor.visit_type(db, type_alias.value_type(db))?; + Ok(()) } #[salsa::tracked] @@ -7965,12 +7850,13 @@ pub struct BareTypeAliasType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for BareTypeAliasType<'_> {} -fn walk_bare_type_alias<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_bare_type_alias<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: BareTypeAliasType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, type_alias.value(db)); +) -> TypeVisitorResult { + visitor.visit_type(db, type_alias.value(db))?; + Ok(()) } impl<'db> BareTypeAliasType<'db> { @@ -7992,19 +7878,20 @@ pub enum TypeAliasType<'db> { Bare(BareTypeAliasType<'db>), } -fn walk_type_alias_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_type_alias_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, type_alias: TypeAliasType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match type_alias { TypeAliasType::PEP695(type_alias) => { - walk_pep_695_type_alias(db, type_alias, visitor); + walk_pep_695_type_alias(db, type_alias, visitor)?; } TypeAliasType::Bare(type_alias) => { - walk_bare_type_alias(db, type_alias, visitor); + walk_bare_type_alias(db, type_alias, visitor)?; } } + Ok(()) } impl<'db> TypeAliasType<'db> { @@ -8059,14 +7946,15 @@ pub struct UnionType<'db> { pub elements: Box<[Type<'db>]>, } -pub(crate) fn walk_union<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +pub(crate) fn walk_union<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, union: UnionType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for element in union.elements(db) { - visitor.visit_type(db, *element); + visitor.visit_type(db, *element)?; } + Ok(()) } // The Salsa heap is tracked separately. @@ -8296,17 +8184,18 @@ pub struct IntersectionType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for IntersectionType<'_> {} -pub(super) fn walk_intersection_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_intersection_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, intersection: IntersectionType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for element in intersection.positive(db) { - visitor.visit_type(db, *element); + visitor.visit_type(db, *element)?; } for element in intersection.negative(db) { - visitor.visit_type(db, *element); + visitor.visit_type(db, *element)?; } + Ok(()) } impl<'db> IntersectionType<'db> { @@ -8696,13 +8585,14 @@ pub struct BoundSuperType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for BoundSuperType<'_> {} -fn walk_bound_super_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_bound_super_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, bound_super: BoundSuperType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, bound_super.pivot_class(db).into()); - visitor.visit_type(db, bound_super.owner(db).into_type()); +) -> TypeVisitorResult { + visitor.visit_type(db, bound_super.pivot_class(db).into())?; + visitor.visit_type(db, bound_super.owner(db).into_type())?; + Ok(()) } impl<'db> BoundSuperType<'db> { @@ -8887,12 +8777,13 @@ pub struct TypeIsType<'db> { place_info: Option<(ScopeId<'db>, ScopedPlaceId)>, } -fn walk_typeis_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>( +fn walk_typeis_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, typeis_type: TypeIsType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, typeis_type.return_type(db)); +) -> TypeVisitorResult { + visitor.visit_type(db, typeis_type.return_type(db))?; + Ok(()) } // The Salsa heap is tracked separately. diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 8ca5516bb18fe..a0c61f2163009 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -20,6 +20,7 @@ use crate::types::generics::{GenericContext, Specialization, walk_specialization use crate::types::infer::nearest_enclosing_class; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::tuple::TupleType; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; use crate::types::{ BareTypeAliasType, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, DynamicType, KnownInstanceType, TypeAliasType, TypeMapping, TypeRelation, @@ -27,7 +28,7 @@ use crate::types::{ infer_definition_types, }; use crate::{ - Db, FxOrderSet, KnownModule, Program, + Db, KnownModule, Program, module_resolver::file_to_module, place::{ Boundness, LookupError, LookupResult, Place, PlaceAndQualifiers, class_symbol, @@ -180,12 +181,13 @@ pub struct GenericAlias<'db> { pub(crate) specialization: Specialization<'db>, } -pub(super) fn walk_generic_alias<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_generic_alias<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, alias: GenericAlias<'db>, visitor: &mut V, -) { - walk_specialization(db, alias.specialization(db), visitor); +) -> TypeVisitorResult { + walk_specialization(db, alias.specialization(db), visitor)?; + Ok(()) } // The Salsa heap is tracked separately. @@ -227,16 +229,6 @@ impl<'db> GenericAlias<'db> { self.specialization(db).apply_type_mapping(db, type_mapping), ) } - - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - // A tuple's specialization will include all of its element types, so we don't need to also - // look in `self.tuple`. - self.specialization(db).find_legacy_typevars(db, typevars); - } } impl<'db> From> for Type<'db> { @@ -365,17 +357,6 @@ impl<'db> ClassType<'db> { } } - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self { - Self::NonGeneric(_) => {} - Self::Generic(generic) => generic.find_legacy_typevars(db, typevars), - } - } - /// Iterate over the [method resolution order] ("MRO") of the class. /// /// If the MRO could not be accurately resolved, this method falls back to iterating diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index d1c7140fd7b78..039b3e68300fa 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -74,13 +74,13 @@ use crate::types::diagnostic::{ use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; -use crate::types::visitor::any_over_type; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult, any_over_type}; use crate::types::{ BoundMethodType, CallableType, ClassLiteral, ClassType, DeprecatedInstance, DynamicType, - KnownClass, Truthiness, Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance, - UnionBuilder, walk_type_mapping, + KnownClass, Truthiness, Type, TypeMapping, TypeRelation, TypeTransformer, UnionBuilder, + walk_type_mapping, }; -use crate::{Db, FxOrderSet, ModuleName, resolve_module}; +use crate::{Db, ModuleName, resolve_module}; /// A collection of useful spans for annotating functions. /// @@ -429,14 +429,15 @@ pub struct FunctionLiteral<'db> { inherited_generic_context: Option>, } -fn walk_function_literal<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +fn walk_function_literal<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, function: FunctionLiteral<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { if let Some(context) = function.inherited_generic_context(db) { - walk_generic_context(db, context, visitor); + walk_generic_context(db, context, visitor)?; } + Ok(()) } #[salsa::tracked] @@ -596,15 +597,16 @@ pub struct FunctionType<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for FunctionType<'_> {} -pub(super) fn walk_function_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_function_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, function: FunctionType<'db>, visitor: &mut V, -) { - walk_function_literal(db, function.literal(db), visitor); +) -> TypeVisitorResult { + walk_function_literal(db, function.literal(db), visitor)?; for mapping in function.type_mappings(db) { - walk_type_mapping(db, mapping, visitor); + walk_type_mapping(db, mapping, visitor)?; } + Ok(()) } #[salsa::tracked] @@ -854,17 +856,6 @@ impl<'db> FunctionType<'db> { self_signature.is_equivalent_to(db, other_signature) } - pub(crate) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - let signatures = self.signature(db); - for signature in &signatures.overloads { - signature.find_legacy_typevars(db, typevars); - } - } - pub(crate) fn normalized(self, db: &'db dyn Db) -> Self { let mut visitor = TypeTransformer::default(); self.normalized_impl(db, &mut visitor) diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index f19a1a477e1c8..f413a9127086e 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -12,6 +12,7 @@ use crate::types::class_base::ClassBase; use crate::types::instance::{NominalInstanceType, Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType}; +use crate::types::visitor::{TypeVisitor, TypeVisitorControlFlow, TypeVisitorResult, visit_type}; use crate::types::{ KnownInstanceType, Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarVariance, UnionType, binding_type, declaration_type, @@ -59,6 +60,23 @@ fn bound_legacy_typevars<'db>( .filter(|typevar| typevar.is_legacy(db)) } +struct FindLegacyTypeVars<'a, 'db>(&'a mut FxOrderSet>); + +impl<'db> TypeVisitor<'db> for FindLegacyTypeVars<'_, 'db> { + fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) -> TypeVisitorResult { + match ty { + Type::TypeVar(typevar) => { + if typevar.is_legacy(db) { + self.0.insert(typevar); + } + Err(TypeVisitorControlFlow::Prune) + } + Type::KnownInstance(_) => Err(TypeVisitorControlFlow::Prune), + _ => Ok(()), + } + } +} + /// A list of formal type variables for a generic function, class, or type alias. /// /// TODO: Handle nested generic contexts better, with actual parent links to the lexically @@ -74,14 +92,15 @@ pub struct GenericContext<'db> { pub(crate) variables: FxOrderSet>, } -pub(super) fn walk_generic_context<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_generic_context<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, context: GenericContext<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for typevar in context.variables(db) { - visitor.visit_type_var_type(db, *typevar); + visitor.visit_type_var_type(db, *typevar)?; } + Ok(()) } // The Salsa heap is tracked separately. @@ -134,14 +153,14 @@ impl<'db> GenericContext<'db> { let mut variables = FxOrderSet::default(); for param in parameters { if let Some(ty) = param.annotated_type() { - ty.find_legacy_typevars(db, &mut variables); + visit_type(db, ty, FindLegacyTypeVars(&mut variables)); } if let Some(ty) = param.default_type() { - ty.find_legacy_typevars(db, &mut variables); + visit_type(db, ty, FindLegacyTypeVars(&mut variables)); } } if let Some(ty) = return_type { - ty.find_legacy_typevars(db, &mut variables); + visit_type(db, ty, FindLegacyTypeVars(&mut variables)); } // Then remove any that were bound in enclosing scopes. @@ -167,7 +186,7 @@ impl<'db> GenericContext<'db> { ) -> Option { let mut variables = FxOrderSet::default(); for base in bases { - base.find_legacy_typevars(db, &mut variables); + visit_type(db, base, FindLegacyTypeVars(&mut variables)); } if variables.is_empty() { return None; @@ -359,18 +378,19 @@ pub struct Specialization<'db> { tuple_inner: Option>, } -pub(super) fn walk_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_specialization<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, specialization: Specialization<'db>, visitor: &mut V, -) { - walk_generic_context(db, specialization.generic_context(db), visitor); +) -> TypeVisitorResult { + walk_generic_context(db, specialization.generic_context(db), visitor)?; for ty in specialization.types(db) { - visitor.visit_type(db, *ty); + visitor.visit_type(db, *ty)?; } if let Some(tuple) = specialization.tuple_inner(db) { - visitor.visit_tuple_type(db, tuple); + visitor.visit_tuple_type(db, tuple)?; } + Ok(()) } impl<'db> Specialization<'db> { @@ -594,16 +614,6 @@ impl<'db> Specialization<'db> { true } - - pub(crate) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for ty in self.types(db) { - ty.find_legacy_typevars(db, typevars); - } - } } /// A mapping between type variables and types. @@ -616,15 +626,16 @@ pub struct PartialSpecialization<'a, 'db> { types: Cow<'a, [Type<'db>]>, } -pub(super) fn walk_partial_specialization<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_partial_specialization<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, specialization: &PartialSpecialization<'_, 'db>, visitor: &mut V, -) { - walk_generic_context(db, specialization.generic_context, visitor); +) -> TypeVisitorResult { + walk_generic_context(db, specialization.generic_context, visitor)?; for ty in &*specialization.types { - visitor.visit_type(db, *ty); + visitor.visit_type(db, *ty)?; } + Ok(()) } impl<'db> PartialSpecialization<'_, 'db> { diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 3959b041e166f..919ec27644aed 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -4,13 +4,14 @@ use std::marker::PhantomData; use super::protocol_class::ProtocolInterface; use super::{ClassType, KnownClass, SubclassOfType, Type, TypeVarVariance}; +use crate::Db; use crate::place::PlaceAndQualifiers; use crate::types::cyclic::PairVisitor; use crate::types::enums::is_single_member_enum; use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::TupleType; -use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance}; -use crate::{Db, FxOrderSet}; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; +use crate::types::{DynamicType, TypeMapping, TypeRelation, TypeTransformer}; pub(super) use synthesized_protocol::SynthesizedProtocolType; @@ -77,12 +78,13 @@ pub struct NominalInstanceType<'db> { _phantom: PhantomData<()>, } -pub(super) fn walk_nominal_instance_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_nominal_instance_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, nominal: NominalInstanceType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, nominal.class.into()); +) -> TypeVisitorResult { + visitor.visit_type(db, nominal.class.into())?; + Ok(()) } impl<'db> NominalInstanceType<'db> { @@ -147,14 +149,6 @@ impl<'db> NominalInstanceType<'db> { ) -> Self { Self::from_class(self.class.apply_type_mapping(db, type_mapping)) } - - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - self.class.find_legacy_typevars(db, typevars); - } } impl<'db> From> for Type<'db> { @@ -177,12 +171,13 @@ pub struct ProtocolInstanceType<'db> { _phantom: PhantomData<()>, } -pub(super) fn walk_protocol_instance_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_protocol_instance_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, visitor: &mut V, -) { - walk_protocol_interface(db, protocol.inner.interface(db), visitor); +) -> TypeVisitorResult { + walk_protocol_interface(db, protocol.inner.interface(db), visitor)?; + Ok(()) } impl<'db> ProtocolInstanceType<'db> { @@ -329,21 +324,6 @@ impl<'db> ProtocolInstanceType<'db> { } } - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self.inner { - Protocol::FromClass(class) => { - class.find_legacy_typevars(db, typevars); - } - Protocol::Synthesized(synthesized) => { - synthesized.find_legacy_typevars(db, typevars); - } - } - } - pub(super) fn interface(self, db: &'db dyn Db) -> ProtocolInterface<'db> { self.inner.interface(db) } @@ -375,9 +355,9 @@ impl<'db> Protocol<'db> { } mod synthesized_protocol { + use crate::Db; use crate::types::protocol_class::ProtocolInterface; - use crate::types::{TypeMapping, TypeTransformer, TypeVarInstance, TypeVarVariance}; - use crate::{Db, FxOrderSet}; + use crate::types::{TypeMapping, TypeTransformer, TypeVarVariance}; /// A "synthesized" protocol type that is dissociated from a class definition in source code. /// @@ -414,14 +394,6 @@ mod synthesized_protocol { Self(self.0.specialized_and_normalized(db, type_mapping)) } - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - self.0.find_legacy_typevars(db, typevars); - } - pub(in crate::types) fn interface(self) -> ProtocolInterface<'db> { self.0 } diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index 4b01ce167045f..e8eeb27eccb18 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -7,14 +7,15 @@ use ruff_python_ast::name::Name; use super::TypeVarVariance; use crate::semantic_index::place_table; use crate::{ - Db, FxOrderSet, + Db, place::{Boundness, Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations}, semantic_index::use_def_map, types::{ CallableType, ClassBase, ClassLiteral, KnownFunction, PropertyInstanceType, Signature, - Type, TypeMapping, TypeQualifiers, TypeRelation, TypeTransformer, TypeVarInstance, + Type, TypeMapping, TypeQualifiers, TypeRelation, TypeTransformer, cyclic::PairVisitor, signatures::{Parameter, Parameters}, + visitor::{TypeVisitor, TypeVisitorResult}, }, }; @@ -77,14 +78,15 @@ pub(super) struct ProtocolInterface<'db> { impl get_size2::GetSize for ProtocolInterface<'_> {} -pub(super) fn walk_protocol_interface<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_protocol_interface<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, interface: ProtocolInterface<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for member in interface.members(db) { - walk_protocol_member(db, &member, visitor); + walk_protocol_member(db, &member, visitor)?; } + Ok(()) } impl<'db> ProtocolInterface<'db> { @@ -205,16 +207,6 @@ impl<'db> ProtocolInterface<'db> { .collect::>(), ) } - - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for data in self.inner(db).values() { - data.find_legacy_typevars(db, typevars); - } - } } #[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)] @@ -242,14 +234,6 @@ impl<'db> ProtocolMemberData<'db> { } } - fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - self.kind.find_legacy_typevars(db, typevars); - } - fn materialize(&self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { Self { kind: self.kind.materialize(db, variance), @@ -294,18 +278,6 @@ impl<'db> ProtocolMemberKind<'db> { } } - fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self { - ProtocolMemberKind::Method(callable) => callable.find_legacy_typevars(db, typevars), - ProtocolMemberKind::Property(property) => property.find_legacy_typevars(db, typevars), - ProtocolMemberKind::Other(ty) => ty.find_legacy_typevars(db, typevars), - } - } - fn materialize(self, db: &'db dyn Db, variance: TypeVarVariance) -> Self { match self { ProtocolMemberKind::Method(callable) => { @@ -329,18 +301,19 @@ pub(super) struct ProtocolMember<'a, 'db> { qualifiers: TypeQualifiers, } -fn walk_protocol_member<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +fn walk_protocol_member<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, member: &ProtocolMember<'_, 'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match member.kind { - ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method), + ProtocolMemberKind::Method(method) => visitor.visit_callable_type(db, method)?, ProtocolMemberKind::Property(property) => { - visitor.visit_property_instance_type(db, property); + visitor.visit_property_instance_type(db, property)?; } - ProtocolMemberKind::Other(ty) => visitor.visit_type(db, ty), + ProtocolMemberKind::Other(ty) => visitor.visit_type(db, ty)?, } + Ok(()) } impl<'a, 'db> ProtocolMember<'a, 'db> { diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index d97aa89513321..cfa58873aff87 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -16,10 +16,11 @@ use itertools::EitherOrBoth; use smallvec::{SmallVec, smallvec_inline}; use super::{DynamicType, Type, TypeTransformer, TypeVarVariance, definition_expression_type}; +use crate::Db; use crate::semantic_index::definition::Definition; use crate::types::generics::{GenericContext, walk_generic_context}; -use crate::types::{TypeMapping, TypeRelation, TypeVarInstance, todo_type}; -use crate::{Db, FxOrderSet}; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; +use crate::types::{TypeMapping, TypeRelation, todo_type}; use ruff_python_ast::{self as ast, name::Name}; /// The signature of a single callable. If the callable is overloaded, there is a separate @@ -85,16 +86,6 @@ impl<'db> CallableSignature<'db> { ) } - pub(crate) fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for signature in &self.overloads { - signature.find_legacy_typevars(db, typevars); - } - } - pub(crate) fn bind_self(&self) -> Self { Self { overloads: self.overloads.iter().map(Signature::bind_self).collect(), @@ -241,27 +232,28 @@ pub struct Signature<'db> { pub(crate) return_ty: Option>, } -pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_signature<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, signature: &Signature<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { if let Some(generic_context) = &signature.generic_context { - walk_generic_context(db, *generic_context, visitor); + walk_generic_context(db, *generic_context, visitor)?; } if let Some(inherited_generic_context) = &signature.inherited_generic_context { - walk_generic_context(db, *inherited_generic_context, visitor); + walk_generic_context(db, *inherited_generic_context, visitor)?; } // By default we usually don't visit the type of the default value, // as it isn't relevant to most things for parameter in &signature.parameters { if let Some(ty) = parameter.annotated_type() { - visitor.visit_type(db, ty); + visitor.visit_type(db, ty)?; } } if let Some(return_ty) = &signature.return_ty { - visitor.visit_type(db, *return_ty); + visitor.visit_type(db, *return_ty)?; } + Ok(()) } impl<'db> Signature<'db> { @@ -413,24 +405,6 @@ impl<'db> Signature<'db> { } } - pub(crate) fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for param in &self.parameters { - if let Some(ty) = param.annotated_type() { - ty.find_legacy_typevars(db, typevars); - } - if let Some(ty) = param.default_type() { - ty.find_legacy_typevars(db, typevars); - } - } - if let Some(ty) = self.return_ty { - ty.find_legacy_typevars(db, typevars); - } - } - /// Return the parameters in this signature. pub(crate) fn parameters(&self) -> &Parameters<'db> { &self.parameters diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs b/crates/ty_python_semantic/src/types/subclass_of.rs index 5b12ae252adbe..716d0050d2fdf 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -1,11 +1,12 @@ use ruff_python_ast::name::Name; +use crate::Db; use crate::place::PlaceAndQualifiers; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; use crate::types::{ ClassType, DynamicType, KnownClass, MemberLookupPolicy, Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance, }; -use crate::{Db, FxOrderSet}; use super::{TypeVarBoundOrConstraints, TypeVarKind, TypeVarVariance}; @@ -16,12 +17,13 @@ pub struct SubclassOfType<'db> { subclass_of: SubclassOfInner<'db>, } -pub(super) fn walk_subclass_of_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_subclass_of_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, subclass_of: SubclassOfType<'db>, visitor: &mut V, -) { - visitor.visit_type(db, Type::from(subclass_of.subclass_of)); +) -> TypeVisitorResult { + visitor.visit_type(db, Type::from(subclass_of.subclass_of))?; + Ok(()) } impl<'db> SubclassOfType<'db> { @@ -120,19 +122,6 @@ impl<'db> SubclassOfType<'db> { } } - pub(super) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self.subclass_of { - SubclassOfInner::Class(class) => { - class.find_legacy_typevars(db, typevars); - } - SubclassOfInner::Dynamic(_) => {} - } - } - pub(crate) fn find_name_in_mro_with_policy( self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 268bd42d1e6c0..782d6a6d63684 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -22,14 +22,15 @@ use std::hash::Hash; use itertools::{Either, EitherOrBoth, Itertools}; +use crate::Db; use crate::types::Truthiness; use crate::types::class::{ClassType, KnownClass}; +use crate::types::visitor::{TypeVisitor, TypeVisitorResult}; use crate::types::{ - Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarInstance, TypeVarVariance, - UnionBuilder, UnionType, cyclic::PairVisitor, + Type, TypeMapping, TypeRelation, TypeTransformer, TypeVarVariance, UnionBuilder, UnionType, + cyclic::PairVisitor, }; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; -use crate::{Db, FxOrderSet}; #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub(crate) enum TupleLength { @@ -104,14 +105,15 @@ pub struct TupleType<'db> { pub(crate) tuple: TupleSpec<'db>, } -pub(super) fn walk_tuple_type<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>( +pub(super) fn walk_tuple_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, tuple: TupleType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { for element in tuple.tuple(db).all_elements() { - visitor.visit_type(db, *element); + visitor.visit_type(db, *element)?; } + Ok(()) } // The Salsa heap is tracked separately. @@ -221,14 +223,6 @@ impl<'db> TupleType<'db> { TupleType::new(db, self.tuple(db).apply_type_mapping(db, type_mapping)) } - pub(crate) fn find_legacy_typevars( - self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - self.tuple(db).find_legacy_typevars(db, typevars); - } - pub(crate) fn has_relation_to( self, db: &'db dyn Db, @@ -384,16 +378,6 @@ impl<'db> FixedLengthTuple> { ) } - fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for ty in &self.0 { - ty.find_legacy_typevars(db, typevars); - } - } - fn has_relation_to( &self, db: &'db dyn Db, @@ -722,20 +706,6 @@ impl<'db> VariableLengthTuple> { ) } - fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - for ty in &self.prefix { - ty.find_legacy_typevars(db, typevars); - } - self.variable.find_legacy_typevars(db, typevars); - for ty in &self.suffix { - ty.find_legacy_typevars(db, typevars); - } - } - fn has_relation_to( &self, db: &'db dyn Db, @@ -1069,17 +1039,6 @@ impl<'db> Tuple> { } } - fn find_legacy_typevars( - &self, - db: &'db dyn Db, - typevars: &mut FxOrderSet>, - ) { - match self { - Tuple::Fixed(tuple) => tuple.find_legacy_typevars(db, typevars), - Tuple::Variable(tuple) => tuple.find_legacy_typevars(db, typevars), - } - } - fn has_relation_to(&self, db: &'db dyn Db, other: &Self, relation: TypeRelation) -> bool { match self { Tuple::Fixed(self_tuple) => self_tuple.has_relation_to(db, other, relation), diff --git a/crates/ty_python_semantic/src/types/visitor.rs b/crates/ty_python_semantic/src/types/visitor.rs index 3c9bb00a1433f..dc22ef1aa4475 100644 --- a/crates/ty_python_semantic/src/types/visitor.rs +++ b/crates/ty_python_semantic/src/types/visitor.rs @@ -16,96 +16,161 @@ use crate::{ }, }; +/// The result returned from the [`TypeVisitor`] trait methods. You can abort the visitor by +/// returning `Err(AbortTypeVisitor)`. +pub(crate) type TypeVisitorResult = Result<(), TypeVisitorControlFlow>; + +/// Controls the behavior of a [`TypeVisitor`], allowing you to skip parts of the type or abort the +/// visiting early. +pub(crate) enum TypeVisitorControlFlow { + /// Abort the entire visitor + Abort, + /// Do not recurse into the current type, but otherwise continue visiting + Prune, +} + /// A visitor trait that recurses into nested types. /// -/// The trait does not guard against infinite recursion out of the box, -/// but it makes it easy for implementors of the trait to do so. -/// See [`any_over_type`] for an example of how to do this. +/// You will typically not call the methods of this trait directly; instead, call the +/// [`visit_type`] function, which handles visiting infinitely recursive types correctly. pub(crate) trait TypeVisitor<'db> { - fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>); + fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) -> TypeVisitorResult; - fn visit_union_type(&mut self, db: &'db dyn Db, union: UnionType<'db>) { - walk_union(db, union, self); + fn visit_union_type(&mut self, db: &'db dyn Db, union: UnionType<'db>) -> TypeVisitorResult { + walk_union(db, union, self) } - fn visit_intersection_type(&mut self, db: &'db dyn Db, intersection: IntersectionType<'db>) { - walk_intersection_type(db, intersection, self); + fn visit_intersection_type( + &mut self, + db: &'db dyn Db, + intersection: IntersectionType<'db>, + ) -> TypeVisitorResult { + walk_intersection_type(db, intersection, self) } - fn visit_tuple_type(&mut self, db: &'db dyn Db, tuple: TupleType<'db>) { - walk_tuple_type(db, tuple, self); + fn visit_tuple_type(&mut self, db: &'db dyn Db, tuple: TupleType<'db>) -> TypeVisitorResult { + walk_tuple_type(db, tuple, self) } - fn visit_callable_type(&mut self, db: &'db dyn Db, callable: CallableType<'db>) { - walk_callable_type(db, callable, self); + fn visit_callable_type( + &mut self, + db: &'db dyn Db, + callable: CallableType<'db>, + ) -> TypeVisitorResult { + walk_callable_type(db, callable, self) } fn visit_property_instance_type( &mut self, db: &'db dyn Db, property: PropertyInstanceType<'db>, - ) { - walk_property_instance_type(db, property, self); + ) -> TypeVisitorResult { + walk_property_instance_type(db, property, self) } - fn visit_typeis_type(&mut self, db: &'db dyn Db, type_is: TypeIsType<'db>) { - walk_typeis_type(db, type_is, self); + fn visit_typeis_type( + &mut self, + db: &'db dyn Db, + type_is: TypeIsType<'db>, + ) -> TypeVisitorResult { + walk_typeis_type(db, type_is, self) } - fn visit_subclass_of_type(&mut self, db: &'db dyn Db, subclass_of: SubclassOfType<'db>) { - walk_subclass_of_type(db, subclass_of, self); + fn visit_subclass_of_type( + &mut self, + db: &'db dyn Db, + subclass_of: SubclassOfType<'db>, + ) -> TypeVisitorResult { + walk_subclass_of_type(db, subclass_of, self) } - fn visit_generic_alias_type(&mut self, db: &'db dyn Db, alias: GenericAlias<'db>) { - walk_generic_alias(db, alias, self); + fn visit_generic_alias_type( + &mut self, + db: &'db dyn Db, + alias: GenericAlias<'db>, + ) -> TypeVisitorResult { + walk_generic_alias(db, alias, self) } - fn visit_function_type(&mut self, db: &'db dyn Db, function: FunctionType<'db>) { - walk_function_type(db, function, self); + fn visit_function_type( + &mut self, + db: &'db dyn Db, + function: FunctionType<'db>, + ) -> TypeVisitorResult { + walk_function_type(db, function, self) } - fn visit_bound_method_type(&mut self, db: &'db dyn Db, method: BoundMethodType<'db>) { - walk_bound_method_type(db, method, self); + fn visit_bound_method_type( + &mut self, + db: &'db dyn Db, + method: BoundMethodType<'db>, + ) -> TypeVisitorResult { + walk_bound_method_type(db, method, self) } - fn visit_bound_super_type(&mut self, db: &'db dyn Db, bound_super: BoundSuperType<'db>) { - walk_bound_super_type(db, bound_super, self); + fn visit_bound_super_type( + &mut self, + db: &'db dyn Db, + bound_super: BoundSuperType<'db>, + ) -> TypeVisitorResult { + walk_bound_super_type(db, bound_super, self) } - fn visit_nominal_instance_type(&mut self, db: &'db dyn Db, nominal: NominalInstanceType<'db>) { - walk_nominal_instance_type(db, nominal, self); + fn visit_nominal_instance_type( + &mut self, + db: &'db dyn Db, + nominal: NominalInstanceType<'db>, + ) -> TypeVisitorResult { + walk_nominal_instance_type(db, nominal, self) } - fn visit_type_var_type(&mut self, db: &'db dyn Db, type_var: TypeVarInstance<'db>) { - walk_type_var_type(db, type_var, self); + fn visit_type_var_type( + &mut self, + db: &'db dyn Db, + type_var: TypeVarInstance<'db>, + ) -> TypeVisitorResult { + walk_type_var_type(db, type_var, self) } fn visit_protocol_instance_type( &mut self, db: &'db dyn Db, protocol: ProtocolInstanceType<'db>, - ) { - walk_protocol_instance_type(db, protocol, self); + ) -> TypeVisitorResult { + walk_protocol_instance_type(db, protocol, self) } fn visit_method_wrapper_type( &mut self, db: &'db dyn Db, method_wrapper: MethodWrapperKind<'db>, - ) { - walk_method_wrapper_type(db, method_wrapper, self); + ) -> TypeVisitorResult { + walk_method_wrapper_type(db, method_wrapper, self) } fn visit_known_instance_type( &mut self, db: &'db dyn Db, known_instance: KnownInstanceType<'db>, - ) { - walk_known_instance_type(db, known_instance, self); + ) -> TypeVisitorResult { + walk_known_instance_type(db, known_instance, self) + } + + fn visit_type_alias_type( + &mut self, + db: &'db dyn Db, + type_alias: TypeAliasType<'db>, + ) -> TypeVisitorResult { + walk_type_alias_type(db, type_alias, self) } +} - fn visit_type_alias_type(&mut self, db: &'db dyn Db, type_alias: TypeAliasType<'db>) { - walk_type_alias_type(db, type_alias, self); +impl<'db, F> TypeVisitor<'db> for F +where + F: FnMut(&'db dyn Db, Type<'db>) -> TypeVisitorResult, +{ + fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) -> TypeVisitorResult { + self(db, ty) } } @@ -198,79 +263,98 @@ fn walk_non_atomic_type<'db, V: TypeVisitor<'db> + ?Sized>( db: &'db dyn Db, non_atomic_type: NonAtomicType<'db>, visitor: &mut V, -) { +) -> TypeVisitorResult { match non_atomic_type { NonAtomicType::FunctionLiteral(function) => visitor.visit_function_type(db, function), NonAtomicType::Intersection(intersection) => { - visitor.visit_intersection_type(db, intersection); + visitor.visit_intersection_type(db, intersection) } NonAtomicType::Union(union) => visitor.visit_union_type(db, union), NonAtomicType::Tuple(tuple) => visitor.visit_tuple_type(db, tuple), NonAtomicType::BoundMethod(method) => visitor.visit_bound_method_type(db, method), NonAtomicType::BoundSuper(bound_super) => visitor.visit_bound_super_type(db, bound_super), NonAtomicType::MethodWrapper(method_wrapper) => { - visitor.visit_method_wrapper_type(db, method_wrapper); + visitor.visit_method_wrapper_type(db, method_wrapper) } NonAtomicType::Callable(callable) => visitor.visit_callable_type(db, callable), NonAtomicType::GenericAlias(alias) => visitor.visit_generic_alias_type(db, alias), NonAtomicType::KnownInstance(known_instance) => { - visitor.visit_known_instance_type(db, known_instance); + visitor.visit_known_instance_type(db, known_instance) } NonAtomicType::SubclassOf(subclass_of) => visitor.visit_subclass_of_type(db, subclass_of), NonAtomicType::NominalInstance(nominal) => visitor.visit_nominal_instance_type(db, nominal), NonAtomicType::PropertyInstance(property) => { - visitor.visit_property_instance_type(db, property); + visitor.visit_property_instance_type(db, property) } NonAtomicType::TypeIs(type_is) => visitor.visit_typeis_type(db, type_is), NonAtomicType::TypeVar(type_var) => visitor.visit_type_var_type(db, type_var), NonAtomicType::ProtocolInstance(protocol) => { - visitor.visit_protocol_instance_type(db, protocol); + visitor.visit_protocol_instance_type(db, protocol) } } } -/// Return `true` if `ty`, or any of the types contained in `ty`, match the closure passed in. +/// Visits a type while guarding against infinite recursion. This lets you write a [`TypeVisitor`] +/// without having to track which types have already been seen. We guarantee that your +/// [`visit_type`][TypeVisitor::visit_type] trait method will only be called once for each distinct +/// non-atomic type that is encountered. /// -/// The function guards against infinite recursion -/// by keeping track of the non-atomic types it has already seen. -pub(super) fn any_over_type<'db>( - db: &'db dyn Db, - ty: Type<'db>, - query: &dyn Fn(Type<'db>) -> bool, -) -> bool { - struct AnyOverTypeVisitor<'db, 'a> { - query: &'a dyn Fn(Type<'db>) -> bool, +/// Note that [`TypeVisitor`] is implemented for any closure with the correct signature, meaning +/// that you often don't need to crate a named type for your visitor. +pub(super) fn visit_type<'db, V>(db: &'db dyn Db, ty: Type<'db>, wrapped: V) +where + V: TypeVisitor<'db>, +{ + struct RecursionGuard<'db, V> { + wrapped: V, seen_types: FxIndexSet>, - found_matching_type: bool, } - impl<'db> TypeVisitor<'db> for AnyOverTypeVisitor<'db, '_> { - fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) { - if self.found_matching_type { - return; - } - self.found_matching_type |= (self.query)(ty); - if self.found_matching_type { - return; - } - match TypeKind::from(ty) { - TypeKind::Atomic => {} - TypeKind::NonAtomic(non_atomic_type) => { - if !self.seen_types.insert(non_atomic_type) { - // If we have already seen this type, we can skip it. - return; + impl<'db, V> TypeVisitor<'db> for RecursionGuard<'db, V> + where + V: TypeVisitor<'db>, + { + fn visit_type(&mut self, db: &'db dyn Db, ty: Type<'db>) -> TypeVisitorResult { + match self.wrapped.visit_type(db, ty) { + err @ Err(TypeVisitorControlFlow::Abort) => return err, + Err(TypeVisitorControlFlow::Prune) => {} + Ok(()) => match TypeKind::from(ty) { + TypeKind::Atomic => {} + TypeKind::NonAtomic(non_atomic_type) => { + if self.seen_types.insert(non_atomic_type) { + // If we haven't already seen this type, we should recurse into it. + walk_non_atomic_type(db, non_atomic_type, self)?; + } } - walk_non_atomic_type(db, non_atomic_type, self); - } + }, } + Ok(()) } } - let mut visitor = AnyOverTypeVisitor { - query, + let mut visitor = RecursionGuard { + wrapped, seen_types: FxIndexSet::default(), - found_matching_type: false, }; - visitor.visit_type(db, ty); - visitor.found_matching_type + let _ = visitor.visit_type(db, ty); +} + +/// Return `true` if `ty`, or any of the types contained in `ty`, match the closure passed in. +pub(super) fn any_over_type<'db>( + db: &'db dyn Db, + ty: Type<'db>, + query: &'db dyn Fn(Type<'db>) -> bool, +) -> bool { + let mut found_matching_type = false; + visit_type(db, ty, |_, ty| { + if found_matching_type { + return Err(TypeVisitorControlFlow::Abort); + } + found_matching_type |= query(ty); + if found_matching_type { + return Err(TypeVisitorControlFlow::Abort); + } + Ok(()) + }); + found_matching_type }