From 4dc2c257efa674b65ce4929d88a770cf0fd011f5 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Wed, 11 Sep 2024 15:05:40 -0400 Subject: [PATCH] [red-knot] Fix type inference for `except*` definitions (#13320) --- .../src/semantic_index/builder.rs | 73 +++++++---- .../src/semantic_index/definition.rs | 44 +++++-- .../src/types/infer.rs | 113 +++++++++++++++--- 3 files changed, 179 insertions(+), 51 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 8fb5a7f9412cc..25d9b9159ba3b 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -27,7 +27,9 @@ use crate::semantic_index::SemanticIndex; use crate::Db; use super::constraint::{Constraint, PatternConstraint}; -use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef}; +use super::definition::{ + ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef, +}; pub(super) struct SemanticIndexBuilder<'db> { // Builder state @@ -696,6 +698,51 @@ where self.flow_merge(after_subject); } } + ast::Stmt::Try(ast::StmtTry { + body, + handlers, + orelse, + finalbody, + is_star, + range: _, + }) => { + self.visit_body(body); + + for except_handler in handlers { + let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler; + let ast::ExceptHandlerExceptHandler { + name: symbol_name, + type_: handled_exceptions, + body: handler_body, + range: _, + } = except_handler; + + if let Some(handled_exceptions) = handled_exceptions { + self.visit_expr(handled_exceptions); + } + + // If `handled_exceptions` above was `None`, it's something like `except as e:`, + // which is invalid syntax. However, it's still pretty obvious here that the user + // *wanted* `e` to be bound, so we should still create a definition here nonetheless. + if let Some(symbol_name) = symbol_name { + let symbol = self + .add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED); + + self.add_definition( + symbol, + DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef { + handler: except_handler, + is_star: *is_star, + }), + ); + } + + self.visit_body(handler_body); + } + + self.visit_body(orelse); + self.visit_body(finalbody); + } _ => { walk_stmt(self, stmt); } @@ -958,30 +1005,6 @@ where self.current_match_case.as_mut().unwrap().index += 1; } - - fn visit_except_handler(&mut self, except_handler: &'ast ast::ExceptHandler) { - let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler; - let ast::ExceptHandlerExceptHandler { - name: symbol_name, - type_: handled_exceptions, - body, - range: _, - } = except_handler; - - if let Some(handled_exceptions) = handled_exceptions { - self.visit_expr(handled_exceptions); - } - - // If `handled_exceptions` above was `None`, it's something like `except as e:`, - // which is invalid syntax. However, it's still pretty obvious here that the user - // *wanted* `e` to be bound, so we should still create a definition here nonetheless. - if let Some(symbol_name) = symbol_name { - let symbol = self.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED); - self.add_definition(symbol, except_handler); - } - - self.visit_body(body); - } } #[derive(Copy, Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 00d51a3a06012..0f7f1a5b15066 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -50,7 +50,7 @@ pub(crate) enum DefinitionNodeRef<'a> { Parameter(ast::AnyParameterRef<'a>), WithItem(WithItemDefinitionNodeRef<'a>), MatchPattern(MatchPatternDefinitionNodeRef<'a>), - ExceptHandler(&'a ast::ExceptHandlerExceptHandler), + ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>), } impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> { @@ -131,12 +131,6 @@ impl<'a> From> for DefinitionNodeRef<'a> { } } -impl<'a> From<&'a ast::ExceptHandlerExceptHandler> for DefinitionNodeRef<'a> { - fn from(node: &'a ast::ExceptHandlerExceptHandler) -> Self { - Self::ExceptHandler(node) - } -} - #[derive(Copy, Clone, Debug)] pub(crate) struct ImportFromDefinitionNodeRef<'a> { pub(crate) node: &'a ast::StmtImportFrom, @@ -162,6 +156,12 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> { pub(crate) is_async: bool, } +#[derive(Copy, Clone, Debug)] +pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> { + pub(crate) handler: &'a ast::ExceptHandlerExceptHandler, + pub(crate) is_star: bool, +} + #[derive(Copy, Clone, Debug)] pub(crate) struct ComprehensionDefinitionNodeRef<'a> { pub(crate) iterable: &'a ast::Expr, @@ -258,9 +258,13 @@ impl DefinitionNodeRef<'_> { identifier: AstNodeRef::new(parsed, identifier), index, }), - DefinitionNodeRef::ExceptHandler(handler) => { - DefinitionKind::ExceptHandler(AstNodeRef::new(parsed, handler)) - } + DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef { + handler, + is_star, + }) => DefinitionKind::ExceptHandler(ExceptHandlerDefinitionKind { + handler: AstNodeRef::new(parsed.clone(), handler), + is_star, + }), } } @@ -293,7 +297,7 @@ impl DefinitionNodeRef<'_> { Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => { identifier.into() } - Self::ExceptHandler(handler) => handler.into(), + Self::ExceptHandler(ExceptHandlerDefinitionNodeRef { handler, .. }) => handler.into(), } } } @@ -314,7 +318,7 @@ pub enum DefinitionKind { ParameterWithDefault(AstNodeRef), WithItem(WithItemDefinitionKind), MatchPattern(MatchPatternDefinitionKind), - ExceptHandler(AstNodeRef), + ExceptHandler(ExceptHandlerDefinitionKind), } #[derive(Clone, Debug)] @@ -430,6 +434,22 @@ impl ForStmtDefinitionKind { } } +#[derive(Clone, Debug)] +pub struct ExceptHandlerDefinitionKind { + handler: AstNodeRef, + is_star: bool, +} + +impl ExceptHandlerDefinitionKind { + pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> { + self.handler.node().type_.as_deref() + } + + pub(crate) fn is_star(&self) -> bool { + self.is_star + } +} + #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub(crate) struct DefinitionNodeKey(NodeKey); diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 30afbafc3c06b..e5415a8b868d2 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -40,7 +40,9 @@ use ruff_text_size::Ranged; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; -use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey}; +use crate::semantic_index::definition::{ + Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind, +}; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId}; @@ -426,8 +428,8 @@ impl<'db> TypeInferenceBuilder<'db> { definition, ); } - DefinitionKind::ExceptHandler(handler) => { - self.infer_except_handler_definition(handler, definition); + DefinitionKind::ExceptHandler(except_handler_definition) => { + self.infer_except_handler_definition(except_handler_definition, definition); } } } @@ -821,22 +823,29 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_except_handler_definition( &mut self, - handler: &'db ast::ExceptHandlerExceptHandler, + except_handler_definition: &ExceptHandlerDefinitionKind, definition: Definition<'db>, ) { - let node_ty = handler - .type_ - .as_deref() + let node_ty = except_handler_definition + .handled_exceptions() .map(|ty| self.infer_expression(ty)) .unwrap_or(Type::Unknown); - // TODO: anything that's a consistent subtype of - // `type[BaseException] | tuple[type[BaseException], ...]` should be valid; - // anything else should be invalid --Alex - let symbol_ty = match node_ty { - Type::Any | Type::Unknown => node_ty, - Type::Class(class_ty) => Type::Instance(class_ty), - _ => Type::Unknown, + let symbol_ty = if except_handler_definition.is_star() { + // TODO should be generic --Alex + // + // TODO should infer `ExceptionGroup` if all caught exceptions + // are subclasses of `Exception` --Alex + builtins_symbol_ty(self.db, "BaseExceptionGroup").to_instance(self.db) + } else { + // TODO: anything that's a consistent subtype of + // `type[BaseException] | tuple[type[BaseException], ...]` should be valid; + // anything else should be invalid --Alex + match node_ty { + Type::Any | Type::Unknown => node_ty, + Type::Class(class_ty) => Type::Instance(class_ty), + _ => Type::Unknown, + } }; self.types.definitions.insert(definition, symbol_ty); @@ -4563,6 +4572,82 @@ mod tests { Ok(()) } + #[test] + fn except_star_handler_baseexception() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + try: + x + except* BaseException as e: + pass + ", + )?; + + assert_file_diagnostics(&db, "src/a.py", &[]); + + // TODO: once we support `sys.version_info` branches, + // we can set `--target-version=py311` in this test + // and the inferred type will just be `BaseExceptionGroup` --Alex + assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup"); + + Ok(()) + } + + #[test] + fn except_star_handler() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + try: + x + except* OSError as e: + pass + ", + )?; + + assert_file_diagnostics(&db, "src/a.py", &[]); + + // TODO: once we support `sys.version_info` branches, + // we can set `--target-version=py311` in this test + // and the inferred type will just be `BaseExceptionGroup` --Alex + // + // TODO more precise would be `ExceptionGroup[OSError]` --Alex + assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup"); + + Ok(()) + } + + #[test] + fn except_star_handler_multiple_types() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + try: + x + except* (TypeError, AttributeError) as e: + pass + ", + )?; + + assert_file_diagnostics(&db, "src/a.py", &[]); + + // TODO: once we support `sys.version_info` branches, + // we can set `--target-version=py311` in this test + // and the inferred type will just be `BaseExceptionGroup` --Alex + // + // TODO more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex + assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup"); + + Ok(()) + } + #[test] fn basic_comprehension() -> anyhow::Result<()> { let mut db = setup_db();