From 46a457318d8d259376a2b458b3f814b9b795fe69 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 4 Sep 2024 11:19:50 +0100 Subject: [PATCH] [red-knot] Add type inference for basic `for` loops (#13195) --- .../red_knot_python_semantic/src/builtins.rs | 16 -- crates/red_knot_python_semantic/src/lib.rs | 2 +- .../src/semantic_model.rs | 4 +- crates/red_knot_python_semantic/src/stdlib.rs | 77 ++++++++ crates/red_knot_python_semantic/src/types.rs | 179 ++++++++++++++---- .../src/types/builder.rs | 11 +- .../src/types/display.rs | 12 +- .../src/types/infer.rs | 161 ++++++++++------ 8 files changed, 331 insertions(+), 131 deletions(-) delete mode 100644 crates/red_knot_python_semantic/src/builtins.rs create mode 100644 crates/red_knot_python_semantic/src/stdlib.rs diff --git a/crates/red_knot_python_semantic/src/builtins.rs b/crates/red_knot_python_semantic/src/builtins.rs deleted file mode 100644 index 7695a621829f4..0000000000000 --- a/crates/red_knot_python_semantic/src/builtins.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::module_name::ModuleName; -use crate::module_resolver::resolve_module; -use crate::semantic_index::global_scope; -use crate::semantic_index::symbol::ScopeId; -use crate::Db; - -/// Salsa query to get the builtins scope. -/// -/// Can return None if a custom typeshed is used that is missing `builtins.pyi`. 
-#[salsa::tracked] -pub(crate) fn builtins_scope(db: &dyn Db) -> Option<ScopeId<'_>> { - let builtins_name = - ModuleName::new_static("builtins").expect("Expected 'builtins' to be a valid module name"); - let builtins_file = resolve_module(db, builtins_name)?.file(); - Some(global_scope(db, builtins_file)) -} diff --git a/crates/red_knot_python_semantic/src/lib.rs b/crates/red_knot_python_semantic/src/lib.rs index 56827bcdd74ae..e5ea3dfd03f75 100644 --- a/crates/red_knot_python_semantic/src/lib.rs +++ b/crates/red_knot_python_semantic/src/lib.rs @@ -10,7 +10,6 @@ pub use python_version::PythonVersion; pub use semantic_model::{HasTy, SemanticModel}; pub mod ast_node_ref; -mod builtins; mod db; mod module_name; mod module_resolver; @@ -20,6 +19,7 @@ mod python_version; pub mod semantic_index; mod semantic_model; pub(crate) mod site_packages; +mod stdlib; pub mod types; type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>; diff --git a/crates/red_knot_python_semantic/src/semantic_model.rs b/crates/red_knot_python_semantic/src/semantic_model.rs index e7320547821b6..fba9213c51948 100644 --- a/crates/red_knot_python_semantic/src/semantic_model.rs +++ b/crates/red_knot_python_semantic/src/semantic_model.rs @@ -8,7 +8,7 @@ use crate::module_name::ModuleName; use crate::module_resolver::{resolve_module, Module}; use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::semantic_index; -use crate::types::{definition_ty, global_symbol_ty_by_name, infer_scope_types, Type}; +use crate::types::{definition_ty, global_symbol_ty, infer_scope_types, Type}; use crate::Db; pub struct SemanticModel<'db> { @@ -40,7 +40,7 @@ impl<'db> SemanticModel<'db> { } pub fn global_symbol_ty(&self, module: &Module, symbol_name: &str) -> Type<'db> { - global_symbol_ty_by_name(self.db, module.file(), symbol_name) + global_symbol_ty(self.db, module.file(), symbol_name) } } diff --git a/crates/red_knot_python_semantic/src/stdlib.rs b/crates/red_knot_python_semantic/src/stdlib.rs new file mode 100644 
index 0000000000000..b80cf4d71ecb0 --- /dev/null +++ b/crates/red_knot_python_semantic/src/stdlib.rs @@ -0,0 +1,77 @@ +use crate::module_name::ModuleName; +use crate::module_resolver::resolve_module; +use crate::semantic_index::global_scope; +use crate::semantic_index::symbol::ScopeId; +use crate::types::{global_symbol_ty, Type}; +use crate::Db; + +/// Enumeration of various core stdlib modules, for which we have dedicated Salsa queries. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CoreStdlibModule { + Builtins, + Types, + Typeshed, +} + +impl CoreStdlibModule { + fn name(self) -> ModuleName { + let module_name = match self { + Self::Builtins => "builtins", + Self::Types => "types", + Self::Typeshed => "_typeshed", + }; + ModuleName::new_static(module_name) + .unwrap_or_else(|| panic!("{module_name} should be a valid module name!")) + } +} + +/// Lookup the type of `symbol` in a given core module +/// +/// Returns `Unbound` if the given core module cannot be resolved for some reason +fn core_module_symbol_ty<'db>( + db: &'db dyn Db, + core_module: CoreStdlibModule, + symbol: &str, +) -> Type<'db> { + resolve_module(db, core_module.name()) + .map(|module| global_symbol_ty(db, module.file(), symbol)) + .unwrap_or(Type::Unbound) +} + +/// Lookup the type of `symbol` in the builtins namespace. +/// +/// Returns `Unbound` if the `builtins` module isn't available for some reason. +#[inline] +pub(crate) fn builtins_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> { + core_module_symbol_ty(db, CoreStdlibModule::Builtins, symbol) +} + +/// Lookup the type of `symbol` in the `types` module namespace. +/// +/// Returns `Unbound` if the `types` module isn't available for some reason. +#[inline] +pub(crate) fn types_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> { + core_module_symbol_ty(db, CoreStdlibModule::Types, symbol) +} + +/// Lookup the type of `symbol` in the `_typeshed` module namespace. 
+/// +/// Returns `Unbound` if the `_typeshed` module isn't available for some reason. +#[inline] +pub(crate) fn typeshed_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> { + core_module_symbol_ty(db, CoreStdlibModule::Typeshed, symbol) +} + +/// Get the scope of a core stdlib module. +/// +/// Can return `None` if a custom typeshed is used that is missing the core module in question. +fn core_module_scope(db: &dyn Db, core_module: CoreStdlibModule) -> Option<ScopeId<'_>> { + resolve_module(db, core_module.name()).map(|module| global_scope(db, module.file())) +} + +/// Get the `builtins` module scope. +/// +/// Can return `None` if a custom typeshed is used that is missing `builtins.pyi`. +pub(crate) fn builtins_module_scope(db: &dyn Db) -> Option<ScopeId<'_>> { + core_module_scope(db, CoreStdlibModule::Builtins) +} diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index e39b82b9fe0ce..acb7c480259d2 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -1,7 +1,6 @@ use ruff_db::files::File; use ruff_python_ast as ast; -use crate::builtins::builtins_scope; use crate::semantic_index::ast_ids::HasScopedAstId; use crate::semantic_index::definition::{Definition, DefinitionKind}; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; @@ -9,6 +8,7 @@ use crate::semantic_index::{ global_scope, semantic_index, symbol_table, use_def_map, DefinitionWithConstraints, DefinitionWithConstraintsIterator, }; +use crate::stdlib::{builtins_symbol_ty, types_symbol_ty, typeshed_symbol_ty}; use crate::types::narrow::narrowing_constraint; use crate::{Db, FxOrderSet}; @@ -40,7 +40,7 @@ pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics { } /// Infer the public type of a symbol (its type as seen from outside its scope). 
-pub(crate) fn symbol_ty<'db>( +pub(crate) fn symbol_ty_by_id<'db>( db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymbolId, @@ -58,30 +58,17 @@ pub(crate) fn symbol_ty<'db>( } /// Shorthand for `symbol_ty` that takes a symbol name instead of an ID. -pub(crate) fn symbol_ty_by_name<'db>( - db: &'db dyn Db, - scope: ScopeId<'db>, - name: &str, -) -> Type<'db> { +pub(crate) fn symbol_ty<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Type<'db> { let table = symbol_table(db, scope); table .symbol_id_by_name(name) - .map(|symbol| symbol_ty(db, scope, symbol)) + .map(|symbol| symbol_ty_by_id(db, scope, symbol)) .unwrap_or(Type::Unbound) } /// Shorthand for `symbol_ty` that looks up a module-global symbol by name in a file. -pub(crate) fn global_symbol_ty_by_name<'db>(db: &'db dyn Db, file: File, name: &str) -> Type<'db> { - symbol_ty_by_name(db, global_scope(db, file), name) -} - -/// Shorthand for `symbol_ty` that looks up a symbol in the builtins. -/// -/// Returns `Unbound` if the builtins module isn't available for some reason. -pub(crate) fn builtins_symbol_ty_by_name<'db>(db: &'db dyn Db, name: &str) -> Type<'db> { - builtins_scope(db) - .map(|builtins| symbol_ty_by_name(db, builtins, name)) - .unwrap_or(Type::Unbound) +pub(crate) fn global_symbol_ty<'db>(db: &'db dyn Db, file: File, name: &str) -> Type<'db> { + symbol_ty(db, global_scope(db, file), name) } /// Infer the type of a [`Definition`]. 
@@ -306,13 +293,9 @@ impl<'db> Type<'db> { pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> { match self { Type::Unbound => replacement, - Type::Union(union) => union - .elements(db) - .into_iter() - .fold(UnionBuilder::new(db), |builder, ty| { - builder.add(ty.replace_unbound_with(db, replacement)) - }) - .build(), + Type::Union(union) => { + union.map(db, |element| element.replace_unbound_with(db, replacement)) + } ty => *ty, } } @@ -331,7 +314,7 @@ impl<'db> Type<'db> { /// us to explicitly consider whether to handle an error or propagate /// it up the call stack. #[must_use] - pub fn member(&self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { + pub fn member(&self, db: &'db dyn Db, name: &str) -> Type<'db> { match self { Type::Any => Type::Any, Type::Never => { @@ -348,19 +331,13 @@ impl<'db> Type<'db> { // TODO: attribute lookup on function type Type::Unknown } - Type::Module(file) => global_symbol_ty_by_name(db, *file, name), + Type::Module(file) => global_symbol_ty(db, *file, name), Type::Class(class) => class.class_member(db, name), Type::Instance(_) => { // TODO MRO? get_own_instance_member, get_instance_member Type::Unknown } - Type::Union(union) => union - .elements(db) - .iter() - .fold(UnionBuilder::new(db), |builder, element_ty| { - builder.add(element_ty.member(db, name)) - }) - .build(), + Type::Union(union) => union.map(db, |element| element.member(db, name)), Type::Intersection(_) => { // TODO perform the get_member on each type in the intersection // TODO return the intersection of those results @@ -415,6 +392,38 @@ impl<'db> Type<'db> { } } + /// Given the type of an object that is iterated over in some way, + /// return the type of objects that are yielded by that iteration. + /// + /// E.g., for the following loop, given the type of `x`, infer the type of `y`: + /// ```python + /// for y in x: + /// pass + /// ``` + /// + /// Returns `None` if `self` represents a type that is not iterable. 
+ fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> { + // `self` represents the type of the iterable; + // `__iter__` and `__next__` are both looked up on the class of the iterable: + let type_of_class = self.to_meta_type(db); + + let dunder_iter_method = type_of_class.member(db, "__iter__"); + if !dunder_iter_method.is_unbound() { + let iterator_ty = dunder_iter_method.call(db)?; + let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); + return dunder_next_method.call(db); + } + + // Although it's not considered great practice, + // classes that define `__getitem__` are also iterable, + // even if they do not define `__iter__`. + // + // TODO this is only valid if the `__getitem__` method is annotated as + // accepting `int` or `SupportsIndex` + let dunder_get_item_method = type_of_class.member(db, "__getitem__"); + dunder_get_item_method.call(db) + } + #[must_use] pub fn to_instance(&self) -> Type<'db> { match self { @@ -424,6 +433,34 @@ impl<'db> Type<'db> { _ => Type::Unknown, // TODO type errors } } + + /// Given a type that is assumed to represent an instance of a class, + /// return a type that represents that class itself. + #[must_use] + pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> { + match self { + Type::Unbound => Type::Unbound, + Type::Never => Type::Never, + Type::Instance(class) => Type::Class(*class), + Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)), + Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"), + Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"), + Type::IntLiteral(_) => builtins_symbol_ty(db, "int"), + Type::Function(_) => types_symbol_ty(db, "FunctionType"), + Type::Module(_) => types_symbol_ty(db, "ModuleType"), + Type::None => typeshed_symbol_ty(db, "NoneType"), + // TODO not accurate if there's a custom metaclass... + Type::Class(_) => builtins_symbol_ty(db, "type"), + // TODO can we do better here? `type[LiteralString]`? 
+ Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty(db, "str"), + // TODO: `type[Any]`? + Type::Any => Type::Any, + // TODO: `type[Unknown]`? + Type::Unknown => Type::Unknown, + // TODO intersections + Type::Intersection(_) => Type::Unknown, + } + } } #[salsa::interned] @@ -504,7 +541,7 @@ impl<'db> ClassType<'db> { /// Returns the class member of this class named `name`. /// /// The member resolves to a member of the class itself or any of its bases. - pub fn class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { + pub fn class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> { let member = self.own_class_member(db, name); if !member.is_unbound() { return member; @@ -514,12 +551,12 @@ impl<'db> ClassType<'db> { } /// Returns the inferred type of the class member named `name`. - pub fn own_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { + pub fn own_class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> { let scope = self.body_scope(db); - symbol_ty_by_name(db, scope, name) + symbol_ty(db, scope, name) } - pub fn inherited_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> { + pub fn inherited_class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> { for base in self.bases(db) { let member = base.member(db, name); if !member.is_unbound() { @@ -542,6 +579,21 @@ impl<'db> UnionType<'db> { pub fn contains(&self, db: &'db dyn Db, ty: Type<'db>) -> bool { self.elements(db).contains(&ty) } + + /// Apply a transformation function to all elements of the union, + /// and create a new union from the resulting set of types + pub fn map( + &self, + db: &'db dyn Db, + mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>, + ) -> Type<'db> { + self.elements(db) + .into_iter() + .fold(UnionBuilder::new(db), |builder, element| { + builder.add(transform_fn(element)) + }) + .build() + } } #[salsa::interned] @@ -688,4 +740,53 @@ mod tests { &["Object of type 'Literal[123]' is not 
callable"], ); } + + #[test] + fn invalid_iterable() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + nonsense = 123 + for x in nonsense: + pass + ", + ) + .unwrap(); + + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); + let a_file_diagnostics = super::check_types(&db, a_file); + assert_diagnostic_messages( + &a_file_diagnostics, + &["Object of type 'Literal[123]' is not iterable"], + ); + } + + #[test] + fn new_iteration_protocol_takes_precedence_over_old_style() { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class NotIterable: + def __getitem__(self, key: int) -> int: + return 42 + + __iter__ = None + + for x in NotIterable(): + pass + ", + ) + .unwrap(); + + let a_file = system_path_to_file(&db, "/src/a.py").unwrap(); + let a_file_diagnostics = super::check_types(&db, a_file); + assert_diagnostic_messages( + &a_file_diagnostics, + &["Object of type 'NotIterable' is not iterable"], + ); + } } diff --git a/crates/red_knot_python_semantic/src/types/builder.rs b/crates/red_knot_python_semantic/src/types/builder.rs index 0ced308f0d7af..c461459f059bb 100644 --- a/crates/red_knot_python_semantic/src/types/builder.rs +++ b/crates/red_knot_python_semantic/src/types/builder.rs @@ -25,13 +25,10 @@ //! * No type in an intersection can be a supertype of any other type in the intersection (just //! eliminate the supertype from the intersection). //! * An intersection containing two non-overlapping types should simplify to [`Type::Never`]. 
- +use crate::types::{builtins_symbol_ty, IntersectionType, Type, UnionType}; use crate::{Db, FxOrderSet}; use ordermap::set::MutableValues; -use super::builtins_symbol_ty_by_name; - pub(crate) struct UnionBuilder<'db> { elements: FxOrderSet<Type<'db>>, db: &'db dyn Db, @@ -68,7 +65,7 @@ impl<'db> UnionBuilder<'db> { if let Some(true_index) = self.elements.get_index_of(&Type::BooleanLiteral(true)) { if self.elements.contains(&Type::BooleanLiteral(false)) { *self.elements.get_index_mut2(true_index).unwrap() = - builtins_symbol_ty_by_name(self.db, "bool"); + builtins_symbol_ty(self.db, "bool"); self.elements.remove(&Type::BooleanLiteral(false)); } } @@ -278,7 +275,7 @@ mod tests { use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; - use crate::types::builtins_symbol_ty_by_name; + use crate::types::builtins_symbol_ty; use crate::ProgramSettings; use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; @@ -351,7 +348,7 @@ mod tests { #[test] fn build_union_bool() { let db = setup_db(); - let bool_ty = builtins_symbol_ty_by_name(&db, "bool"); + let bool_ty = builtins_symbol_ty(&db, "bool"); let t0 = Type::BooleanLiteral(true); let t1 = Type::BooleanLiteral(true); diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 8b27f1a197e12..49241154994a6 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -236,9 +236,7 @@ mod tests { use ruff_db::system::{DbWithTestSystem, SystemPathBuf}; use crate::db::tests::TestDb; - use crate::types::{ - global_symbol_ty_by_name, BytesLiteralType, StringLiteralType, Type, UnionBuilder, - }; + use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionBuilder}; use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings}; fn setup_db() -> TestDb 
{ @@ -283,16 +281,16 @@ mod tests { let vec: Vec<Type<'_>> = vec![ Type::Unknown, Type::IntLiteral(-1), - global_symbol_ty_by_name(&db, mod_file, "A"), + global_symbol_ty(&db, mod_file, "A"), Type::StringLiteral(StringLiteralType::new(&db, Box::from("A"))), Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([0]))), Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([7]))), Type::IntLiteral(0), Type::IntLiteral(1), Type::StringLiteral(StringLiteralType::new(&db, Box::from("B"))), - global_symbol_ty_by_name(&db, mod_file, "foo"), - global_symbol_ty_by_name(&db, mod_file, "bar"), - global_symbol_ty_by_name(&db, mod_file, "B"), + global_symbol_ty(&db, mod_file, "foo"), + global_symbol_ty(&db, mod_file, "bar"), + global_symbol_ty(&db, mod_file, "B"), Type::BooleanLiteral(true), Type::None, ]; diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index b9b8f900b731e..98a038afb0c6d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -37,7 +37,6 @@ use ruff_db::parsed::parsed_module; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp}; use ruff_text_size::Ranged; -use crate::builtins::builtins_scope; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; @@ -46,11 +45,11 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId}; use crate::semantic_index::SemanticIndex; +use crate::stdlib::builtins_module_scope; use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics}; use crate::types::{ - builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, symbol_ty, - symbol_ty_by_name, BytesLiteralType, ClassType, FunctionType, StringLiteralType, Type, 
- UnionBuilder, + builtins_symbol_ty, definitions_ty, global_symbol_ty, symbol_ty, symbol_ty_by_id, + BytesLiteralType, ClassType, FunctionType, StringLiteralType, Type, UnionBuilder, }; use crate::Db; @@ -1043,18 +1042,17 @@ impl<'db> TypeInferenceBuilder<'db> { .types .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); - // TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__` - // member (dunders are never looked up on an instance) - let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__")); - - // TODO(Alex): - // - infer the return type of the `__iter__` method, which gives us the iterator - // - lookup the `__next__` method on the iterator - // - infer the return type of the iterator's `__next__` method, - // which gives us the type of the variable being bound here - // (...or the type of the object being unpacked into multiple definitions, if it's something like - // `for k, v in d.items(): ...`) - let loop_var_value_ty = Type::Unknown; + let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| { + self.add_diagnostic( + iterable.into(), + "not-iterable", + format_args!( + "Object of type '{}' is not iterable", + iterable_ty.display(self.db) + ), + ); + Type::Unknown + }); self.types .expressions @@ -1400,11 +1398,9 @@ impl<'db> TypeInferenceBuilder<'db> { ast::Number::Int(n) => n .as_i64() .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()), - ast::Number::Float(_) => builtins_symbol_ty_by_name(self.db, "float").to_instance(), - ast::Number::Complex { .. } => { - builtins_symbol_ty_by_name(self.db, "complex").to_instance() - } + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), + ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(), + ast::Number::Complex { .. 
} => builtins_symbol_ty(self.db, "complex").to_instance(), } } @@ -1482,12 +1478,11 @@ impl<'db> TypeInferenceBuilder<'db> { } } - #[allow(clippy::unused_self)] fn infer_ellipsis_literal_expression( &mut self, _literal: &ast::ExprEllipsisLiteral, ) -> Type<'db> { - builtins_symbol_ty_by_name(self.db, "Ellipsis") + builtins_symbol_ty(self.db, "Ellipsis") } fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> { @@ -1503,7 +1498,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty_by_name(self.db, "tuple").to_instance() + builtins_symbol_ty(self.db, "tuple").to_instance() } fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> { @@ -1518,7 +1513,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty_by_name(self.db, "list").to_instance() + builtins_symbol_ty(self.db, "list").to_instance() } fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { @@ -1529,7 +1524,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty_by_name(self.db, "set").to_instance() + builtins_symbol_ty(self.db, "set").to_instance() } fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { @@ -1541,7 +1536,7 @@ impl<'db> TypeInferenceBuilder<'db> { } // TODO generic - builtins_symbol_ty_by_name(self.db, "dict").to_instance() + builtins_symbol_ty(self.db, "dict").to_instance() } /// Infer the type of the `iter` expression of the first comprehension. @@ -1884,7 +1879,7 @@ impl<'db> TypeInferenceBuilder<'db> { // runtime, it is the scope that creates the cell for our closure.) If the name // isn't bound in that scope, we should get an unbound name, not continue // falling back to other scopes / globals / builtins. - return symbol_ty_by_name(self.db, enclosing_scope_id, name); + return symbol_ty(self.db, enclosing_scope_id, name); } } // No nonlocal binding, check module globals. 
Avoid infinite recursion if `self.scope` @@ -1892,11 +1887,11 @@ impl<'db> TypeInferenceBuilder<'db> { let ty = if file_scope_id.is_global() { Type::Unbound } else { - global_symbol_ty_by_name(self.db, self.file, name) + global_symbol_ty(self.db, self.file, name) }; // Fallback to builtins (without infinite recursion if we're already in builtins.) - if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_scope(self.db) { - ty.replace_unbound_with(self.db, builtins_symbol_ty_by_name(self.db, name)) + if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_module_scope(self.db) { + ty.replace_unbound_with(self.db, builtins_symbol_ty(self.db, name)) } else { ty } @@ -1915,7 +1910,7 @@ impl<'db> TypeInferenceBuilder<'db> { let symbol = symbols .symbol_id_by_name(id) .expect("Expected the symbol table to create a symbol for every Name node"); - return symbol_ty(self.db, self.scope, symbol); + return symbol_ty_by_id(self.db, self.scope, symbol); } match ctx { @@ -1986,22 +1981,22 @@ impl<'db> TypeInferenceBuilder<'db> { (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n .checked_add(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n .checked_sub(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n .checked_mul(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n .checked_div(m) .map(Type::IntLiteral) - .unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, 
"int").to_instance()), + .unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()), (Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n .checked_rem(m) @@ -2380,14 +2375,14 @@ mod tests { use ruff_db::testing::assert_function_query_was_not_run; use ruff_python_ast::name::Name; - use crate::builtins::builtins_scope; use crate::db::tests::TestDb; use crate::program::{Program, SearchPathSettings}; use crate::python_version::PythonVersion; use crate::semantic_index::definition::Definition; use crate::semantic_index::symbol::FileScopeId; use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map}; - use crate::types::{global_symbol_ty_by_name, infer_definition_types, symbol_ty_by_name}; + use crate::stdlib::builtins_module_scope; + use crate::types::{global_symbol_ty, infer_definition_types, symbol_ty}; use crate::{HasTy, ProgramSettings, SemanticModel}; use super::TypeInferenceBuilder; @@ -2440,7 +2435,7 @@ mod tests { fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) { let file = system_path_to_file(db, file_name).expect("Expected file to exist."); - let ty = global_symbol_ty_by_name(db, file, symbol_name); + let ty = global_symbol_ty(db, file, symbol_name); assert_eq!(ty.display(db).to_string(), expected); } @@ -2465,7 +2460,7 @@ mod tests { assert_eq!(scope.name(db), *expected_scope_name); } - let ty = symbol_ty_by_name(db, scope, symbol_name); + let ty = symbol_ty(db, scope, symbol_name); assert_eq!(ty.display(db).to_string(), expected); } @@ -2669,7 +2664,7 @@ mod tests { )?; let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist."); - let ty = global_symbol_ty_by_name(&db, mod_file, "Sub"); + let ty = global_symbol_ty(&db, mod_file, "Sub"); let class = ty.expect_class(); @@ -2696,7 +2691,7 @@ mod tests { )?; let mod_file = system_path_to_file(&db, "src/mod.py").unwrap(); - let ty = global_symbol_ty_by_name(&db, mod_file, "C"); + let ty = 
global_symbol_ty(&db, mod_file, "C"); let class_id = ty.expect_class(); let member_ty = class_id.class_member(&db, &Name::new_static("f")); let func = member_ty.expect_function(); @@ -2900,7 +2895,7 @@ mod tests { db.write_file("src/a.py", "def example() -> int: return 42")?; let mod_file = system_path_to_file(&db, "src/a.py").unwrap(); - let function = global_symbol_ty_by_name(&db, mod_file, "example").expect_function(); + let function = global_symbol_ty(&db, mod_file, "example").expect_function(); let returns = function.return_type(&db); assert_eq!(returns.display(&db).to_string(), "int"); @@ -2975,6 +2970,52 @@ mod tests { Ok(()) } + #[test] + fn basic_for_loop() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class IntIterator: + def __next__(self) -> int: + return 42 + + class IntIterable: + def __iter__(self) -> IntIterator: + return IntIterator() + + for x in IntIterable(): + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "int"); + + Ok(()) + } + + #[test] + fn for_loop_with_old_style_iteration_protocol() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + class OldStyleIterable: + def __getitem__(self, key: int) -> int: + return 42 + + for x in OldStyleIterable(): + pass + ", + )?; + + assert_public_ty(&db, "src/a.py", "x", "int"); + + Ok(()) + } + #[test] fn class_constructor_call_expression() -> anyhow::Result<()> { let mut db = setup_db(); @@ -3317,7 +3358,7 @@ mod tests { )?; let a = system_path_to_file(&db, "src/a.py").expect("Expected file to exist."); - let c_ty = global_symbol_ty_by_name(&db, a, "C"); + let c_ty = global_symbol_ty(&db, a, "C"); let c_class = c_ty.expect_class(); let mut c_bases = c_class.bases(&db); let b_ty = c_bases.next().unwrap(); @@ -3354,8 +3395,8 @@ mod tests { .unwrap() .0 .to_scope_id(&db, file); - let y_ty = symbol_ty_by_name(&db, function_scope, "y"); - let x_ty = symbol_ty_by_name(&db, function_scope, "x"); + let y_ty = 
symbol_ty(&db, function_scope, "y"); + let x_ty = symbol_ty(&db, function_scope, "x"); assert_eq!(y_ty.display(&db).to_string(), "Unbound"); assert_eq!(x_ty.display(&db).to_string(), "Literal[2]"); @@ -3385,8 +3426,8 @@ mod tests { .unwrap() .0 .to_scope_id(&db, file); - let y_ty = symbol_ty_by_name(&db, function_scope, "y"); - let x_ty = symbol_ty_by_name(&db, function_scope, "x"); + let y_ty = symbol_ty(&db, function_scope, "y"); + let x_ty = symbol_ty(&db, function_scope, "x"); assert_eq!(x_ty.display(&db).to_string(), "Unbound"); assert_eq!(y_ty.display(&db).to_string(), "Literal[1]"); @@ -3416,7 +3457,7 @@ mod tests { .unwrap() .0 .to_scope_id(&db, file); - let y_ty = symbol_ty_by_name(&db, function_scope, "y"); + let y_ty = symbol_ty(&db, function_scope, "y"); assert_eq!( y_ty.display(&db).to_string(), @@ -3450,8 +3491,8 @@ mod tests { .unwrap() .0 .to_scope_id(&db, file); - let y_ty = symbol_ty_by_name(&db, class_scope, "y"); - let x_ty = symbol_ty_by_name(&db, class_scope, "x"); + let y_ty = symbol_ty(&db, class_scope, "y"); + let x_ty = symbol_ty(&db, class_scope, "x"); assert_eq!(x_ty.display(&db).to_string(), "Unbound | Literal[2]"); assert_eq!(y_ty.display(&db).to_string(), "Literal[1]"); @@ -3544,9 +3585,11 @@ mod tests { assert_public_ty(&db, "/src/a.py", "x", "Literal[copyright]"); // imported builtins module is the same file as the implicit builtins let file = system_path_to_file(&db, "/src/a.py").expect("Expected file to exist."); - let builtins_ty = global_symbol_ty_by_name(&db, file, "builtins"); + let builtins_ty = global_symbol_ty(&db, file, "builtins"); let builtins_file = builtins_ty.expect_module(); - let implicit_builtins_file = builtins_scope(&db).expect("builtins to exist").file(&db); + let implicit_builtins_file = builtins_module_scope(&db) + .expect("builtins module should exist") + .file(&db); assert_eq!(builtins_file, implicit_builtins_file); Ok(()) @@ -3850,7 +3893,7 @@ mod tests { ])?; let a = system_path_to_file(&db, 
"/src/a.py").unwrap(); - let x_ty = global_symbol_ty_by_name(&db, a, "x"); + let x_ty = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); @@ -3859,7 +3902,7 @@ mod tests { let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty_2 = global_symbol_ty_by_name(&db, a, "x"); + let x_ty_2 = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]"); @@ -3876,7 +3919,7 @@ mod tests { ])?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty = global_symbol_ty_by_name(&db, a, "x"); + let x_ty = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); @@ -3886,7 +3929,7 @@ mod tests { db.clear_salsa_events(); - let x_ty_2 = global_symbol_ty_by_name(&db, a, "x"); + let x_ty_2 = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]"); @@ -3912,7 +3955,7 @@ mod tests { ])?; let a = system_path_to_file(&db, "/src/a.py").unwrap(); - let x_ty = global_symbol_ty_by_name(&db, a, "x"); + let x_ty = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty.display(&db).to_string(), "Literal[10]"); @@ -3922,7 +3965,7 @@ mod tests { db.clear_salsa_events(); - let x_ty_2 = global_symbol_ty_by_name(&db, a, "x"); + let x_ty_2 = global_symbol_ty(&db, a, "x"); assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");