diff --git a/crates/ty_python_semantic/resources/mdtest/binary/integers.md b/crates/ty_python_semantic/resources/mdtest/binary/integers.md index 30561981810b3..3834a90e0ba20 100644 --- a/crates/ty_python_semantic/resources/mdtest/binary/integers.md +++ b/crates/ty_python_semantic/resources/mdtest/binary/integers.md @@ -1,5 +1,11 @@ # Binary operations on integers +> Developer's note: This is mainly a test for the behavior of the type inferer. The constant +> evaluator (`resolve_to_literal`) of `SemanticIndexBuilder` is implemented separately from the type +> inferer, so if you modify the contents of this file or the type inferer, please also modify the +> implementation of `resolve_to_literal` and the unit tests (semantic_index/tests/const_eval\_\*) at +> the same time. + ## Basic Arithmetic ```py diff --git a/crates/ty_python_semantic/src/semantic_index.rs b/crates/ty_python_semantic/src/semantic_index.rs index e49d3188104ec..d64bde0cef810 100644 --- a/crates/ty_python_semantic/src/semantic_index.rs +++ b/crates/ty_python_semantic/src/semantic_index.rs @@ -925,6 +925,50 @@ mod tests { .collect() } + /// A function to test how the constant evaluator of `SemanticIndexBuilder` evaluates an expression + /// (the evaluation should match that of `TypeInferenceBuilder`). + /// For example, for the input `x = 1\nif cond: x = 2\nx`, if `cond` evaluates to `AlwaysTrue`, it returns `vec![2]`, + /// if it evaluates to `AlwaysFalse`, it returns `vec![1]`, ​​if it evaluates to `Ambiguous`, it returns `vec![1, 2]`. + fn reachable_bindings_for_terminal_use(content: &str) -> Vec { + let TestCase { db, file } = test_case(content); + let scope = global_scope(&db, file); + let module = parsed_module(&db, file).load(&db); + let ast = module.syntax(); + + let terminal_expr = ast + .body + .last() + .and_then(ast::Stmt::as_expr_stmt) + .map(|stmt| stmt.value.as_ref()) + .expect("expected terminal expression statement"); + let terminal_name = terminal_expr + .as_name_expr() + .expect("terminal expression should be a name"); + + let use_id = terminal_name.scoped_use_id(&db, scope); + let use_def = use_def_map(&db, scope); + + use_def + .bindings_at_use(use_id) + .filter_map(|binding_with_constraints| { + let definition = binding_with_constraints.binding.definition()?; + let DefinitionKind::Assignment(assignment) = definition.kind(&db) else { + return None; + }; + + let ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(value), + .. + }) = assignment.value(&module) + else { + return None; + }; + + value.as_i64() + }) + .collect::>() + } + #[test] fn empty() { let TestCase { db, file } = test_case(""); @@ -1590,6 +1634,71 @@ class C[T]: assert_eq!(*num, 1); } + #[test] + fn const_eval_lshift_overflow_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 << 63: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + + #[test] + fn const_eval_lshift_zero_short_circuit() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 0 << 4000000000000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_positive() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 1 >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![1]); + } + + #[test] + fn const_eval_rshift_large_negative_operand() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if (-1) >> 5000000000: + x = 2 +x +", + ); + assert_eq!(values, vec![2]); + } + + #[test] + fn const_eval_negative_lshift_is_ambiguous() { + let values = reachable_bindings_for_terminal_use( + " +x = 1 +if 42 << -3: + x = 2 +x +", + ); + assert_eq!(values, vec![1, 2]); + } + #[test] fn expression_scope() { let TestCase { db, file } = test_case("x = 1;\ndef test():\n y = 4"); diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 17fa4147f0625..c6387b48283c6 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -913,27 +913,253 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> { } fn build_predicate(&mut self, predicate_node: &ast::Expr) -> PredicateOrLiteral<'db> { + /// Returns if the expression is a `TYPE_CHECKING` expression. + fn is_if_type_checking(expr: &ast::Expr) -> bool { + fn is_dotted_name(expr: &ast::Expr) -> bool { + match expr { + ast::Expr::Name(_) => true, + ast::Expr::Attribute(ast::ExprAttribute { value, .. }) => is_dotted_name(value), + _ => false, + } + } + + match expr { + ast::Expr::Name(ast::ExprName { id, .. }) => id == "TYPE_CHECKING", + ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { + attr == "TYPE_CHECKING" && is_dotted_name(value) + } + _ => false, + } + } + // Some commonly used test expressions are eagerly evaluated as `true` // or `false` here for performance reasons. This list does not need to // be exhaustive. More complex expressions will still evaluate to the // correct value during type-checking. fn resolve_to_literal(node: &ast::Expr) -> Option { - match node { - ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => Some(*value), - ast::Expr::Name(ast::ExprName { id, .. }) if id == "TYPE_CHECKING" => Some(true), - ast::Expr::NumberLiteral(ast::ExprNumberLiteral { - value: ast::Number::Int(n), - .. - }) => Some(*n != 0), - ast::Expr::EllipsisLiteral(_) => Some(true), - ast::Expr::NoneLiteral(_) => Some(false), - ast::Expr::UnaryOp(ast::ExprUnaryOp { - op: ast::UnaryOp::Not, - operand, - .. - }) => Some(!resolve_to_literal(operand)?), - _ => None, + #[derive(Copy, Clone)] + enum ConstExpr { + Bool(bool), + Int(i64), + None, + Ellipsis, + } + + impl ConstExpr { + fn truthiness(self) -> bool { + match self { + ConstExpr::Bool(value) => value, + ConstExpr::Int(value) => value != 0, + ConstExpr::None => false, + ConstExpr::Ellipsis => true, + } + } + + fn as_int(self) -> Option { + match self { + ConstExpr::Int(value) => Some(value), + ConstExpr::Bool(value) => Some(i64::from(value)), + _ => None, + } + } } + + fn resolve_const_expr(node: &ast::Expr) -> Option { + match node { + ast::Expr::BooleanLiteral(ast::ExprBooleanLiteral { value, .. }) => { + Some(ConstExpr::Bool(*value)) + } + ast::Expr::NumberLiteral(ast::ExprNumberLiteral { + value: ast::Number::Int(n), + .. + }) => n.as_i64().map(ConstExpr::Int), + ast::Expr::EllipsisLiteral(_) => Some(ConstExpr::Ellipsis), + ast::Expr::NoneLiteral(_) => Some(ConstExpr::None), + // See also: `TypeInferenceBuilder::infer_unary_expression_type` + ast::Expr::UnaryOp(ast::ExprUnaryOp { op, operand, .. }) => { + let operand = resolve_const_expr(operand)?; + match op { + ast::UnaryOp::Not => Some(ConstExpr::Bool(!operand.truthiness())), + ast::UnaryOp::UAdd => Some(ConstExpr::Int(operand.as_int()?)), + ast::UnaryOp::USub => { + Some(ConstExpr::Int(operand.as_int()?.checked_neg()?)) + } + ast::UnaryOp::Invert => Some(ConstExpr::Int(!operand.as_int()?)), + } + } + // See also: `TypeInferenceBuilder::infer_binary_expression_type` + ast::Expr::BinOp(ast::ExprBinOp { + left, op, right, .. + }) => { + let left = resolve_const_expr(left)?.as_int()?; + let right = resolve_const_expr(right)?.as_int()?; + let value = match op { + ast::Operator::Add => left.checked_add(right)?, + ast::Operator::Sub => left.checked_sub(right)?, + ast::Operator::Mult => left.checked_mul(right)?, + ast::Operator::FloorDiv => { + let mut q = left.checked_div(right); + let r = left.checked_rem(right); + // Division works differently in Python than in Rust. If the + // result is negative and there is a remainder, floor division + // rounds down (instead of toward zero). + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + q = q.map(|q| q - 1); + } + q? + } + ast::Operator::Mod => { + let mut r = left.checked_rem(right); + // Python's modulo keeps the sign of the divisor. Adjust the Rust + // remainder accordingly so that `q * right + r == left`. + if left.is_negative() != right.is_negative() && r.unwrap_or(0) != 0 + { + r = r.map(|x| x + right); + } + r? + } + ast::Operator::BitAnd => left & right, + ast::Operator::BitOr => left | right, + ast::Operator::BitXor => left ^ right, + ast::Operator::LShift => { + if left == 0 && right >= 0 { + 0 + } else { + // An additional overflow check beyond `checked_shl` is + // necessary here, because `checked_shl` only rejects shift + // amounts >= 64; it does not detect when significant bits + // are shifted into (or past) the sign bit. + // + // We compute the "headroom": the number of redundant + // sign-extension bits minus one (for the sign bit itself). + // A shift is safe iff `shift <= headroom`. + let headroom = if left >= 0 { + left.leading_zeros().saturating_sub(1) + } else { + left.leading_ones().saturating_sub(1) + }; + u32::try_from(right) + .ok() + .filter(|&shift| shift <= headroom) + .and_then(|shift| left.checked_shl(shift))? + } + } + ast::Operator::RShift => match u32::try_from(right) { + Ok(shift) => left >> shift.clamp(0, 63), + Err(_) if right > 0 => { + if left >= 0 { + 0 + } else { + -1 + } + } + Err(_) => return None, + }, + ast::Operator::Pow => { + let exp = u32::try_from(right).ok()?; + left.checked_pow(exp)? + } + ast::Operator::Div | ast::Operator::MatMult => return None, + }; + Some(ConstExpr::Int(value)) + } + ast::Expr::BoolOp(ast::ExprBoolOp { op, values, .. }) => { + let value = match op { + ast::BoolOp::And => { + let mut all_true = true; + for expr in values { + if !resolve_const_expr(expr)?.truthiness() { + all_true = false; + break; + } + } + all_true + } + ast::BoolOp::Or => { + let mut any_true = false; + for expr in values { + if resolve_const_expr(expr)?.truthiness() { + any_true = true; + break; + } + } + any_true + } + }; + Some(ConstExpr::Bool(value)) + } + ast::Expr::Compare(ast::ExprCompare { + left, + ops, + comparators, + .. + }) => { + let mut left_value = resolve_const_expr(left)?; + for (op, comparator) in ops.iter().zip(comparators.iter()) { + let right_value = resolve_const_expr(comparator)?; + let eq = |left: ConstExpr, right: ConstExpr| match (left, right) { + (ConstExpr::Int(left), ConstExpr::Int(right)) => { + Some(left == right) + } + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) => Some(true), + (ConstExpr::None | ConstExpr::Ellipsis, _) + | (_, ConstExpr::None | ConstExpr::Ellipsis) => Some(false), + _ => None, + }; + let result = match op { + ast::CmpOp::Eq => eq(left_value, right_value)?, + ast::CmpOp::NotEq => !eq(left_value, right_value)?, + ast::CmpOp::Lt => left_value.as_int()? < right_value.as_int()?, + ast::CmpOp::LtE => left_value.as_int()? <= right_value.as_int()?, + ast::CmpOp::Gt => left_value.as_int()? > right_value.as_int()?, + ast::CmpOp::GtE => left_value.as_int()? >= right_value.as_int()?, + ast::CmpOp::Is => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => true, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => false, + _ => return None, + }, + ast::CmpOp::IsNot => match (left_value, right_value) { + (ConstExpr::None, ConstExpr::None) + | (ConstExpr::Ellipsis, ConstExpr::Ellipsis) + | (ConstExpr::Bool(true), ConstExpr::Bool(true)) + | (ConstExpr::Bool(false), ConstExpr::Bool(false)) => false, + ( + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + _, + ) + | ( + _, + ConstExpr::None | ConstExpr::Ellipsis | ConstExpr::Bool(_), + ) => true, + _ => return None, + }, + ast::CmpOp::In | ast::CmpOp::NotIn => return None, + }; + if !result { + return Some(ConstExpr::Bool(false)); + } + left_value = right_value; + } + Some(ConstExpr::Bool(true)) + } + _ if is_if_type_checking(node) => Some(ConstExpr::Bool(true)), + _ => None, + } + } + + Some(resolve_const_expr(node)?.truthiness()) } let expression = self.add_standalone_expression(predicate_node);