From 3b006de3299ef5e1195bac270d19732f04a2b3ac Mon Sep 17 00:00:00 2001 From: David Peter Date: Wed, 2 Apr 2025 18:08:30 +0200 Subject: [PATCH] [red-knot] Detect division-by-zero in unions and intersections --- .../resources/mdtest/binary/unions.md | 8 ++ .../src/types/infer.rs | 77 +++++++++++++------ 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/binary/unions.md b/crates/red_knot_python_semantic/resources/mdtest/binary/unions.md index 0f5fb09bf5aff..1ec0794cc4db7 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/binary/unions.md +++ b/crates/red_knot_python_semantic/resources/mdtest/binary/unions.md @@ -49,3 +49,11 @@ def f4(x: float, y: float): reveal_type(x // y) # revealed: int | float reveal_type(x % y) # revealed: int | float ``` + +If any of the union elements leads to a division by zero, we will report an error: + +```py +def f5(m: int, n: Literal[-1, 0, 1]): + # error: [division-by-zero] "Cannot divide object of type `int` by zero" + return m / n +``` diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index ba38c2cee263b..243f68a83f8c2 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -971,7 +971,12 @@ impl<'db> TypeInferenceBuilder<'db> { /// Raise a diagnostic if the given type cannot be divided by zero. /// /// Expects the resolved type of the left side of the binary expression. - fn check_division_by_zero(&mut self, expr: &ast::ExprBinOp, left: Type<'db>) { + fn check_division_by_zero( + &mut self, + node: AnyNodeRef<'_>, + op: ast::Operator, + left: Type<'db>, + ) -> bool { match left { Type::BooleanLiteral(_) | Type::IntLiteral(_) => {} Type::Instance(instance) @@ -979,24 +984,26 @@ impl<'db> TypeInferenceBuilder<'db> { instance.class().known(self.db()), Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool) ) => {} - _ => return, + _ => return false, }; - let (op, by_zero) = match expr.op { + let (op, by_zero) = match op { ast::Operator::Div => ("divide", "by zero"), ast::Operator::FloorDiv => ("floor divide", "by zero"), ast::Operator::Mod => ("reduce", "modulo zero"), - _ => return, + _ => return false, }; self.context.report_lint( &DIVISION_BY_ZERO, - expr, + node, format_args!( "Cannot {op} object of type `{}` {by_zero}", left.display(self.db()) ), ); + + true } fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { @@ -2858,7 +2865,7 @@ impl<'db> TypeInferenceBuilder<'db> { // Fall back to non-augmented binary operator inference. let mut binary_return_ty = || { - self.infer_binary_expression_type(target_type, value_type, op) + self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op) .unwrap_or_else(|| { report_unsupported_augmented_op(&mut self.context); Type::unknown() @@ -4495,19 +4502,7 @@ impl<'db> TypeInferenceBuilder<'db> { let left_ty = self.infer_expression(left); let right_ty = self.infer_expression(right); - // Check for division by zero; this doesn't change the inferred type for the expression, but - // may emit a diagnostic - if matches!( - (op, right_ty), - ( - ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod, - Type::IntLiteral(0) | Type::BooleanLiteral(false) - ) - ) { - self.check_division_by_zero(binary, left_ty); - } - - self.infer_binary_expression_type(left_ty, right_ty, *op) + self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op) .unwrap_or_else(|| { self.context.report_lint( &UNSUPPORTED_OPERATOR, @@ -4524,15 +4519,37 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_binary_expression_type( &mut self, + node: AnyNodeRef<'_>, + mut emitted_division_by_zero_diagnostic: bool, left_ty: Type<'db>, right_ty: Type<'db>, op: ast::Operator, ) -> Option> { + // Check for division by zero; this doesn't change the inferred type for the expression, but + // may emit a diagnostic + if !emitted_division_by_zero_diagnostic + && matches!( + (op, right_ty), + ( + ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod, + Type::IntLiteral(0) | Type::BooleanLiteral(false) + ) + ) + { + emitted_division_by_zero_diagnostic = self.check_division_by_zero(node, op, left_ty); + } + match (left_ty, right_ty, op) { (Type::Union(lhs_union), rhs, _) => { let mut union = UnionBuilder::new(self.db()); for lhs in lhs_union.elements(self.db()) { - let result = self.infer_binary_expression_type(*lhs, rhs, op)?; + let result = self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, + *lhs, + rhs, + op, + )?; union = union.add(result); } Some(union.build()) @@ -4540,7 +4557,13 @@ impl<'db> TypeInferenceBuilder<'db> { (lhs, Type::Union(rhs_union), _) => { let mut union = UnionBuilder::new(self.db()); for rhs in rhs_union.elements(self.db()) { - let result = self.infer_binary_expression_type(lhs, *rhs, op)?; + let result = self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, + lhs, + *rhs, + op, + )?; union = union.add(result); } Some(union.build()) @@ -4659,13 +4682,19 @@ impl<'db> TypeInferenceBuilder<'db> { } (Type::BooleanLiteral(bool_value), right, op) => self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, Type::IntLiteral(i64::from(bool_value)), right, op, ), - (left, Type::BooleanLiteral(bool_value), op) => { - self.infer_binary_expression_type(left, Type::IntLiteral(i64::from(bool_value)), op) - } + (left, Type::BooleanLiteral(bool_value), op) => self.infer_binary_expression_type( + node, + emitted_division_by_zero_diagnostic, + left, + Type::IntLiteral(i64::from(bool_value)), + op, + ), (Type::Tuple(lhs), Type::Tuple(rhs), ast::Operator::Add) => { // Note: this only works on heterogeneous tuples.