Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ if True and (x := 1):

```py
def _(flag: bool):
flag or (x := 1) or reveal_type(x) # revealed: Literal[1]
flag or (x := 1) or reveal_type(x) # revealed: Never

# error: [unresolved-reference]
flag or reveal_type(y) or (y := 1) # revealed: Unknown
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,15 @@ def _(x: str | None, y: str | None):
if y is not x:
reveal_type(y) # revealed: str | None
```

## Assignment expressions

```py
def f() -> bool:
return True

if x := f():
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[False]
```
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,16 @@ def _(flag1: bool, flag2: bool):
# TODO should be Never
reveal_type(x) # revealed: Literal[1, 2]
```

## Assignment expressions

```py
def f() -> int | str | None: ...

if isinstance(x := f(), int):
reveal_type(x) # revealed: int
elif isinstance(x, str):
reveal_type(x) # revealed: str & ~int
else:
reveal_type(x) # revealed: None
```
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,17 @@ def _(x: Literal[1, "a", "b", "c", "d"]):
else:
reveal_type(x) # revealed: Literal[1, "d"]
```

## Assignment expressions

```py
from typing import Literal

def f() -> Literal[1, 2, 3]:
return 1

if (x := f()) in (1,):
reveal_type(x) # revealed: Literal[1]
else:
reveal_type(x) # revealed: Literal[2, 3]
```
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,16 @@ def _(flag: bool):
else:
reveal_type(x) # revealed: Literal[42]
```

## Assignment expressions

```py
from typing import Literal

def f() -> Literal[1, 2] | None: ...

if (x := f()) is None:
reveal_type(x) # revealed: None
else:
reveal_type(x) # revealed: Literal[1, 2]
```
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,14 @@ def _(x_flag: bool, y_flag: bool):
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
```

## Assignment expressions

```py
def f() -> int | str | None: ...

if (x := f()) is not None:
reveal_type(x) # revealed: int | str
else:
reveal_type(x) # revealed: None
```
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,18 @@ def _(flag1: bool, flag2: bool, a: int):
else:
reveal_type(x) # revealed: Literal[1, 2]
```

## Assignment expressions

```py
from typing import Literal

def f() -> Literal[1, 2, 3]:
return 1

if (x := f()) != 1:
reveal_type(x) # revealed: Literal[2, 3]
else:
# TODO should be Literal[1]
reveal_type(x) # revealed: Literal[1, 2, 3]
```
127 changes: 84 additions & 43 deletions crates/red_knot_python_semantic/src/types/narrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,8 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
self.evaluate_expression_node_predicate(&unary_op.operand, expression, !is_positive)
}
ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive),
_ => None, // TODO other test expression kinds
ast::Expr::Named(expr_named) => self.evaluate_expr_named(expr_named, is_positive),
_ => None,
}
}

Expand Down Expand Up @@ -343,6 +344,18 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
NarrowingConstraints::from_iter([(symbol, ty)])
}

fn evaluate_expr_named(
&mut self,
expr_named: &ast::ExprNamed,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
if let ast::Expr::Name(expr_name) = expr_named.target.as_ref() {
Some(self.evaluate_expr_name(expr_name, is_positive))
} else {
None
}
}

fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
match rhs_ty {
Expand All @@ -365,14 +378,55 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}

fn evaluate_expr_compare_op(
&mut self,
lhs_ty: Type<'db>,
rhs_ty: Type<'db>,
op: ast::CmpOp,
) -> Option<Type<'db>> {
match op {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
// Non-singletons cannot be safely narrowed using `is not`
None
}
}
ast::CmpOp::Is => Some(rhs_ty),
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
Some(ty)
} else {
None
}
}
ast::CmpOp::Eq if lhs_ty.is_literal_string() => Some(rhs_ty),
ast::CmpOp::In => self.evaluate_expr_in(lhs_ty, rhs_ty),
ast::CmpOp::NotIn => self
.evaluate_expr_in(lhs_ty, rhs_ty)
.map(|ty| ty.negate(self.db)),
_ => None,
}
}

fn evaluate_expr_compare(
&mut self,
expr_compare: &ast::ExprCompare,
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
fn is_narrowing_target_candidate(expr: &ast::Expr) -> bool {
matches!(expr, ast::Expr::Name(_) | ast::Expr::Call(_))
matches!(
expr,
ast::Expr::Name(_) | ast::Expr::Call(_) | ast::Expr::Named(_)
)
}

let ast::ExprCompare {
Expand Down Expand Up @@ -423,43 +477,24 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}) => {
let symbol = self.expect_expr_name_symbol(id);

match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
if rhs_ty.is_singleton(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
} else {
// Non-singletons cannot be safely narrowed using `is not`
}
}
ast::CmpOp::Is => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::NotEq => {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(rhs_ty)
.build();
constraints.insert(symbol, ty);
}
}
ast::CmpOp::Eq if lhs_ty.is_literal_string() => {
constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::In => {
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
constraints.insert(symbol, ty);
}
}
ast::CmpOp::NotIn => {
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
constraints.insert(symbol, ty.negate(self.db));
}
}
_ => {
// TODO other comparison types
let op = if is_positive { *op } else { op.negate() };

if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
}
}
ast::Expr::Named(ast::ExprNamed {
range: _,
target,
value: _,
}) => {
if let ast::Expr::Name(ast::ExprName { id, .. }) = target.as_ref() {
let symbol = self.expect_expr_name_symbol(id);

let op = if is_positive { *op } else { op.negate() };

if let Some(ty) = self.evaluate_expr_compare_op(lhs_ty, rhs_ty, op) {
constraints.insert(symbol, ty);
}
}
}
Expand Down Expand Up @@ -535,10 +570,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => {
let function = function_type.known(self.db)?.into_constraint_function()?;

let [ast::Expr::Name(ast::ExprName { id, .. }), class_info] =
&*expr_call.arguments.args
else {
return None;
let (id, class_info) = match &*expr_call.arguments.args {
[first, class_info] => match first {
ast::Expr::Named(ast::ExprNamed { target, .. }) => match target.as_ref() {
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
_ => return None,
},
ast::Expr::Name(ast::ExprName { id, .. }) => (id, class_info),
_ => return None,
},
_ => return None,
};

let symbol = self.expect_expr_name_symbol(id);
Expand Down
Loading