Skip to content

Commit

Permalink
[red-knot] Fix type inference for except* definitions (#13320)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood authored Sep 11, 2024
1 parent b72d49b commit 4dc2c25
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 51 deletions.
73 changes: 48 additions & 25 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use crate::semantic_index::SemanticIndex;
use crate::Db;

use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
use super::definition::{
ExceptHandlerDefinitionNodeRef, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
};

pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
Expand Down Expand Up @@ -696,6 +698,51 @@ where
self.flow_merge(after_subject);
}
}
ast::Stmt::Try(ast::StmtTry {
body,
handlers,
orelse,
finalbody,
is_star,
range: _,
}) => {
self.visit_body(body);

for except_handler in handlers {
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
let ast::ExceptHandlerExceptHandler {
name: symbol_name,
type_: handled_exceptions,
body: handler_body,
range: _,
} = except_handler;

if let Some(handled_exceptions) = handled_exceptions {
self.visit_expr(handled_exceptions);
}

// If `handled_exceptions` above was `None`, it's something like `except as e:`,
// which is invalid syntax. However, it's still pretty obvious here that the user
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
if let Some(symbol_name) = symbol_name {
let symbol = self
.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);

self.add_definition(
symbol,
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
handler: except_handler,
is_star: *is_star,
}),
);
}

self.visit_body(handler_body);
}

self.visit_body(orelse);
self.visit_body(finalbody);
}
_ => {
walk_stmt(self, stmt);
}
Expand Down Expand Up @@ -958,30 +1005,6 @@ where

self.current_match_case.as_mut().unwrap().index += 1;
}

fn visit_except_handler(&mut self, except_handler: &'ast ast::ExceptHandler) {
let ast::ExceptHandler::ExceptHandler(except_handler) = except_handler;
let ast::ExceptHandlerExceptHandler {
name: symbol_name,
type_: handled_exceptions,
body,
range: _,
} = except_handler;

if let Some(handled_exceptions) = handled_exceptions {
self.visit_expr(handled_exceptions);
}

// If `handled_exceptions` above was `None`, it's something like `except as e:`,
// which is invalid syntax. However, it's still pretty obvious here that the user
// *wanted* `e` to be bound, so we should still create a definition here nonetheless.
if let Some(symbol_name) = symbol_name {
let symbol = self.add_or_update_symbol(symbol_name.id.clone(), SymbolFlags::IS_DEFINED);
self.add_definition(symbol, except_handler);
}

self.visit_body(body);
}
}

#[derive(Copy, Clone, Debug)]
Expand Down
44 changes: 32 additions & 12 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub(crate) enum DefinitionNodeRef<'a> {
Parameter(ast::AnyParameterRef<'a>),
WithItem(WithItemDefinitionNodeRef<'a>),
MatchPattern(MatchPatternDefinitionNodeRef<'a>),
ExceptHandler(&'a ast::ExceptHandlerExceptHandler),
ExceptHandler(ExceptHandlerDefinitionNodeRef<'a>),
}

impl<'a> From<&'a ast::StmtFunctionDef> for DefinitionNodeRef<'a> {
Expand Down Expand Up @@ -131,12 +131,6 @@ impl<'a> From<MatchPatternDefinitionNodeRef<'a>> for DefinitionNodeRef<'a> {
}
}

impl<'a> From<&'a ast::ExceptHandlerExceptHandler> for DefinitionNodeRef<'a> {
fn from(node: &'a ast::ExceptHandlerExceptHandler) -> Self {
Self::ExceptHandler(node)
}
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) node: &'a ast::StmtImportFrom,
Expand All @@ -162,6 +156,12 @@ pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) is_async: bool,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ExceptHandlerDefinitionNodeRef<'a> {
pub(crate) handler: &'a ast::ExceptHandlerExceptHandler,
pub(crate) is_star: bool,
}

#[derive(Copy, Clone, Debug)]
pub(crate) struct ComprehensionDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
Expand Down Expand Up @@ -258,9 +258,13 @@ impl DefinitionNodeRef<'_> {
identifier: AstNodeRef::new(parsed, identifier),
index,
}),
DefinitionNodeRef::ExceptHandler(handler) => {
DefinitionKind::ExceptHandler(AstNodeRef::new(parsed, handler))
}
DefinitionNodeRef::ExceptHandler(ExceptHandlerDefinitionNodeRef {
handler,
is_star,
}) => DefinitionKind::ExceptHandler(ExceptHandlerDefinitionKind {
handler: AstNodeRef::new(parsed.clone(), handler),
is_star,
}),
}
}

Expand Down Expand Up @@ -293,7 +297,7 @@ impl DefinitionNodeRef<'_> {
Self::MatchPattern(MatchPatternDefinitionNodeRef { identifier, .. }) => {
identifier.into()
}
Self::ExceptHandler(handler) => handler.into(),
Self::ExceptHandler(ExceptHandlerDefinitionNodeRef { handler, .. }) => handler.into(),
}
}
}
Expand All @@ -314,7 +318,7 @@ pub enum DefinitionKind {
ParameterWithDefault(AstNodeRef<ast::ParameterWithDefault>),
WithItem(WithItemDefinitionKind),
MatchPattern(MatchPatternDefinitionKind),
ExceptHandler(AstNodeRef<ast::ExceptHandlerExceptHandler>),
ExceptHandler(ExceptHandlerDefinitionKind),
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -430,6 +434,22 @@ impl ForStmtDefinitionKind {
}
}

#[derive(Clone, Debug)]
pub struct ExceptHandlerDefinitionKind {
handler: AstNodeRef<ast::ExceptHandlerExceptHandler>,
is_star: bool,
}

impl ExceptHandlerDefinitionKind {
pub(crate) fn handled_exceptions(&self) -> Option<&ast::Expr> {
self.handler.node().type_.as_deref()
}

pub(crate) fn is_star(&self) -> bool {
self.is_star
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct DefinitionNodeKey(NodeKey);

Expand Down
113 changes: 99 additions & 14 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ use ruff_text_size::Ranged;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
use crate::semantic_index::definition::{
Definition, DefinitionKind, DefinitionNodeKey, ExceptHandlerDefinitionKind,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
Expand Down Expand Up @@ -426,8 +428,8 @@ impl<'db> TypeInferenceBuilder<'db> {
definition,
);
}
DefinitionKind::ExceptHandler(handler) => {
self.infer_except_handler_definition(handler, definition);
DefinitionKind::ExceptHandler(except_handler_definition) => {
self.infer_except_handler_definition(except_handler_definition, definition);
}
}
}
Expand Down Expand Up @@ -821,22 +823,29 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_except_handler_definition(
&mut self,
handler: &'db ast::ExceptHandlerExceptHandler,
except_handler_definition: &ExceptHandlerDefinitionKind,
definition: Definition<'db>,
) {
let node_ty = handler
.type_
.as_deref()
let node_ty = except_handler_definition
.handled_exceptions()
.map(|ty| self.infer_expression(ty))
.unwrap_or(Type::Unknown);

// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else should be invalid --Alex
let symbol_ty = match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::Class(class_ty) => Type::Instance(class_ty),
_ => Type::Unknown,
let symbol_ty = if except_handler_definition.is_star() {
// TODO should be generic --Alex
//
// TODO should infer `ExceptionGroup` if all caught exceptions
// are subclasses of `Exception` --Alex
builtins_symbol_ty(self.db, "BaseExceptionGroup").to_instance(self.db)
} else {
// TODO: anything that's a consistent subtype of
// `type[BaseException] | tuple[type[BaseException], ...]` should be valid;
// anything else should be invalid --Alex
match node_ty {
Type::Any | Type::Unknown => node_ty,
Type::Class(class_ty) => Type::Instance(class_ty),
_ => Type::Unknown,
}
};

self.types.definitions.insert(definition, symbol_ty);
Expand Down Expand Up @@ -4563,6 +4572,82 @@ mod tests {
Ok(())
}

#[test]
fn except_star_handler_baseexception() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
try:
x
except* BaseException as e:
pass
",
)?;

assert_file_diagnostics(&db, "src/a.py", &[]);

// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");

Ok(())
}

#[test]
fn except_star_handler() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
try:
x
except* OSError as e:
pass
",
)?;

assert_file_diagnostics(&db, "src/a.py", &[]);

// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
//
// TODO more precise would be `ExceptionGroup[OSError]` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");

Ok(())
}

#[test]
fn except_star_handler_multiple_types() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
try:
x
except* (TypeError, AttributeError) as e:
pass
",
)?;

assert_file_diagnostics(&db, "src/a.py", &[]);

// TODO: once we support `sys.version_info` branches,
// we can set `--target-version=py311` in this test
// and the inferred type will just be `BaseExceptionGroup` --Alex
//
// TODO more precise would be `ExceptionGroup[TypeError | AttributeError]` --Alex
assert_public_ty(&db, "src/a.py", "e", "Unknown | BaseExceptionGroup");

Ok(())
}

#[test]
fn basic_comprehension() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 4dc2c25

Please sign in to comment.