diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md index 927670b5081af..602265039db27 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/type.md @@ -144,3 +144,13 @@ def _(x: Base): # express a constraint like `Base & ~ProperSubtypeOf[Base]`. reveal_type(x) # revealed: Base ``` + +## Assignment expressions + +```py +def _(x: object): + if (y := type(x)) is bool: + reveal_type(y) # revealed: Literal[bool] + if (type(y := x)) is bool: + reveal_type(y) # revealed: bool +``` diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 50cfb9b931b84..b884d1da850ac 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -238,6 +238,17 @@ fn negate_if<'db>(constraints: &mut NarrowingConstraints<'db>, db: &'db dyn Db, } } +fn expr_name(expr: &ast::Expr) -> Option<&ast::name::Name> { + match expr { + ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() { + ast::Expr::Name(ast::ExprName { id, .. }) => Some(id), + _ => None, + }, + ast::Expr::Name(ast::ExprName { id, .. }) => Some(id), + _ => None, + } +} + struct NarrowingConstraintsBuilder<'db> { db: &'db dyn Db, predicate: PredicateNode<'db>, @@ -497,27 +508,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { last_rhs_ty = Some(rhs_ty); match left { - ast::Expr::Name(ast::ExprName { - range: _, - id, - ctx: _, - }) => { - let symbol = self.expect_expr_name_symbol(id); - - let op = if is_positive { *op } else { op.negate() }; - - if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { - constraints.insert(symbol, ty); - } - } - ast::Expr::Named(ast::ExprNamed { - range: _, - target, - value: _, - }) => { - if let ast::Expr::Name(ast::ExprName { id, .. }) = target.as_ref() { + ast::Expr::Name(_) | ast::Expr::Named(_) => { + if let Some(id) = expr_name(left) { let symbol = self.expect_expr_name_symbol(id); - let op = if is_positive { *op } else { op.negate() }; if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) { @@ -545,8 +538,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } }; - let [ast::Expr::Name(ast::ExprName { id, .. })] = &**args else { - continue; + let id = match &**args { + [first] => match expr_name(first) { + Some(id) => id, + None => continue, + }, + _ => continue, }; let is_valid_constraint = if is_positive { @@ -598,13 +595,9 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let function = function_type.known(self.db)?.into_constraint_function()?; let (id, class_info) = match &*expr_call.arguments.args { - [first, class_info] => match first { - ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() { - ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info), - _ => return None, - }, - ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info), - _ => return None, + [first, class_info] => match expr_name(first) { + Some(id) => (id, class_info), + None => return None, }, _ => return None, };