From 8144a11f98032a1f109f557fbd5c8379364230b6 Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Thu, 22 Aug 2024 08:00:19 +0530 Subject: [PATCH] [red-knot] Add definition for with items (#12920) ## Summary This PR adds symbols and definitions introduced by `with` statements. The symbols and definitions are introduced for each with item. The type inference is updated to call the definition region type inference instead. ## Test Plan Add test case to check for symbol table and definitions. --- .../src/semantic_index.rs | 50 +++++++++++++++++++ .../src/semantic_index/builder.rs | 30 +++++++++++ .../src/semantic_index/definition.rs | 38 +++++++++++++- .../src/types/infer.rs | 36 ++++++++++++- 4 files changed, 151 insertions(+), 3 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index.rs b/crates/red_knot_python_semantic/src/semantic_index.rs index 0c5942f05f4d8..d7603ddef387a 100644 --- a/crates/red_knot_python_semantic/src/semantic_index.rs +++ b/crates/red_knot_python_semantic/src/semantic_index.rs @@ -790,6 +790,56 @@ def f(a: str, /, b: str, c: int = 1, *args, d: int = 2, **kwargs): assert_eq!(names(&inner_comprehension_symbol_table), vec!["x"]); } + #[test] + fn with_item_definition() { + let TestCase { db, file } = test_case( + " +with item1 as x, item2 as y: + pass +", + ); + + let index = semantic_index(&db, file); + let global_table = index.symbol_table(FileScopeId::global()); + + assert_eq!(names(&global_table), vec!["item1", "x", "item2", "y"]); + + let use_def = index.use_def_map(FileScopeId::global()); + for name in ["x", "y"] { + let Some(definition) = use_def.first_public_definition( + global_table.symbol_id_by_name(name).expect("symbol exists"), + ) else { + panic!("Expected with item definition for {name}"); + }; + assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_))); + } + } + + #[test] + fn with_item_unpacked_definition() { + let TestCase { db, file } = test_case( + " +with context() as (x, y): + pass +", + ); + + let index = semantic_index(&db, file); + let global_table = index.symbol_table(FileScopeId::global()); + + assert_eq!(names(&global_table), vec!["context", "x", "y"]); + + let use_def = index.use_def_map(FileScopeId::global()); + for name in ["x", "y"] { + let Some(definition) = use_def.first_public_definition( + global_table.symbol_id_by_name(name).expect("symbol exists"), + ) else { + panic!("Expected with item definition for {name}"); + }; + assert!(matches!(definition.node(&db), DefinitionKind::WithItem(_))); + } + } + #[test] fn dupes() { let TestCase { db, file } = test_case( diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 860df6c257a74..049712feaf753 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -26,6 +26,8 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder}; use crate::semantic_index::SemanticIndex; use crate::Db; +use super::definition::WithItemDefinitionNodeRef; + pub(super) struct SemanticIndexBuilder<'db> { // Builder state db: &'db dyn Db, @@ -561,6 +563,18 @@ where self.flow_merge(break_state); } } + ast::Stmt::With(ast::StmtWith { items, body, .. }) => { + for item in items { + self.visit_expr(&item.context_expr); + if let Some(optional_vars) = item.optional_vars.as_deref() { + self.add_standalone_expression(&item.context_expr); + self.current_assignment = Some(item.into()); + self.visit_expr(optional_vars); + self.current_assignment = None; + } + } + self.visit_body(body); + } ast::Stmt::Break(_) => { self.loop_break_states.push(self.flow_snapshot()); } @@ -622,6 +636,15 @@ where ComprehensionDefinitionNodeRef { node, first }, ); } + Some(CurrentAssignment::WithItem(with_item)) => { + self.add_definition( + symbol, + WithItemDefinitionNodeRef { + node: with_item, + target: name_node, + }, + ); + } None => {} } } @@ -778,6 +801,7 @@ enum CurrentAssignment<'a> { node: &'a ast::Comprehension, first: bool, }, + WithItem(&'a ast::WithItem), } impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> { @@ -803,3 +827,9 @@ impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { Self::Named(value) } } + +impl<'a> From<&'a ast::WithItem> for CurrentAssignment<'a> { + fn from(value: &'a ast::WithItem) -> Self { + Self::WithItem(value) + } +} diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 38ccaf5849f48..68c56f763fb0c 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -47,6 +47,7 @@ pub(crate) enum DefinitionNodeRef<'a> { AugmentedAssignment(&'a ast::StmtAugAssign), Comprehension(ComprehensionDefinitionNodeRef<'a>), Parameter(ast::AnyParameterRef<'a>), + WithItem(WithItemDefinitionNodeRef<'a>), } impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { @@ -97,6 +98,12 @@ impl<'a> From> for DefinitionNodeRef<'a> { } } +impl<'a> From> for DefinitionNodeRef<'a> { + fn from(node_ref: WithItemDefinitionNodeRef<'a>) -> Self { + Self::WithItem(node_ref) + } +} + impl<'a> From> for DefinitionNodeRef<'a> { fn from(node: ComprehensionDefinitionNodeRef<'a>) -> Self { Self::Comprehension(node) @@ -121,6 +128,12 @@ pub(crate) struct AssignmentDefinitionNodeRef<'a> { pub(crate) target: &'a ast::ExprName, } +#[derive(Copy, Clone, Debug)] +pub(crate) struct WithItemDefinitionNodeRef<'a> { + pub(crate) node: &'a ast::WithItem, + pub(crate) target: &'a ast::ExprName, +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'a> { pub(crate) node: &'a ast::Comprehension, @@ -175,6 +188,12 @@ impl DefinitionNodeRef<'_> { DefinitionKind::ParameterWithDefault(AstNodeRef::new(parsed, parameter)) } }, + DefinitionNodeRef::WithItem(WithItemDefinitionNodeRef { node, target }) => { + DefinitionKind::WithItem(WithItemDefinitionKind { + node: AstNodeRef::new(parsed.clone(), node), + target: AstNodeRef::new(parsed, target), + }) + } } } @@ -198,6 +217,7 @@ impl DefinitionNodeRef<'_> { ast::AnyParameterRef::Variadic(parameter) => parameter.into(), ast::AnyParameterRef::NonVariadic(parameter) => parameter.into(), }, + Self::WithItem(WithItemDefinitionNodeRef { node: _, target }) => target.into(), } } } @@ -215,6 +235,7 @@ pub enum DefinitionKind { Comprehension(ComprehensionDefinitionKind), Parameter(AstNodeRef), ParameterWithDefault(AstNodeRef), + WithItem(WithItemDefinitionKind), } #[derive(Clone, Debug)] @@ -250,7 +271,6 @@ impl ImportFromDefinitionKind { } #[derive(Clone, Debug)] -#[allow(dead_code)] pub struct AssignmentDefinitionKind { assignment: AstNodeRef, target: AstNodeRef, @@ -266,6 +286,22 @@ impl AssignmentDefinitionKind { } } +#[derive(Clone, Debug)] +pub struct WithItemDefinitionKind { + node: AstNodeRef, + target: AstNodeRef, +} + +impl WithItemDefinitionKind { + pub(crate) fn node(&self) -> &ast::WithItem { + self.node.node() + } + + pub(crate) fn target(&self) -> &ast::ExprName { + self.target.node() + } +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub(crate) struct DefinitionNodeKey(NodeKey); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 9b183727e7920..28fb0a002ccd0 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -333,6 +333,9 @@ impl<'db> TypeInferenceBuilder<'db> { DefinitionKind::ParameterWithDefault(parameter_with_default) => { self.infer_parameter_with_default_definition(parameter_with_default, definition); } + DefinitionKind::WithItem(with_item) => { + self.infer_with_item_definition(with_item.target(), with_item.node(), definition); + } } } @@ -618,13 +621,42 @@ impl<'db> TypeInferenceBuilder<'db> { } = with_statement; for item in items { - self.infer_expression(&item.context_expr); - self.infer_optional_expression(item.optional_vars.as_deref()); + match item.optional_vars.as_deref() { + Some(ast::Expr::Name(name)) => { + self.infer_definition(name); + } + _ => { + // TODO infer definitions in unpacking assignment + self.infer_expression(&item.context_expr); + } + } } self.infer_body(body); } + fn infer_with_item_definition( + &mut self, + target: &ast::ExprName, + with_item: &ast::WithItem, + definition: Definition<'db>, + ) { + let expression = self.index.expression(&with_item.context_expr); + let result = infer_expression_types(self.db, expression); + self.extend(result); + + // TODO(dhruvmanila): The correct type inference here is the return type of the __enter__ + // method of the context manager. + let context_expr_ty = self + .types + .expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope)); + + self.types + .expressions + .insert(target.scoped_ast_id(self.db, self.scope), context_expr_ty); + self.types.definitions.insert(definition, context_expr_ty); + } + fn infer_match_statement(&mut self, match_statement: &ast::StmtMatch) { let ast::StmtMatch { range: _,