Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ def f4(x: float, y: float):
reveal_type(x // y) # revealed: int | float
reveal_type(x % y) # revealed: int | float
```

If any of the union elements leads to a division by zero, we will report an error:

```py
def f5(m: int, n: Literal[-1, 0, 1]):
# error: [division-by-zero] "Cannot divide object of type `int` by zero"
return m / n
```
77 changes: 53 additions & 24 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,32 +971,39 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Raise a diagnostic if the given type cannot be divided by zero.
///
/// Expects the resolved type of the left side of the binary expression.
fn check_division_by_zero(&mut self, expr: &ast::ExprBinOp, left: Type<'db>) {
fn check_division_by_zero(
&mut self,
node: AnyNodeRef<'_>,
op: ast::Operator,
left: Type<'db>,
) -> bool {
match left {
Type::BooleanLiteral(_) | Type::IntLiteral(_) => {}
Type::Instance(instance)
if matches!(
instance.class().known(self.db()),
Some(KnownClass::Float | KnownClass::Int | KnownClass::Bool)
) => {}
_ => return,
_ => return false,
};

let (op, by_zero) = match expr.op {
let (op, by_zero) = match op {
ast::Operator::Div => ("divide", "by zero"),
ast::Operator::FloorDiv => ("floor divide", "by zero"),
ast::Operator::Mod => ("reduce", "modulo zero"),
_ => return,
_ => return false,
};

self.context.report_lint(
&DIVISION_BY_ZERO,
expr,
node,
format_args!(
"Cannot {op} object of type `{}` {by_zero}",
left.display(self.db())
),
);

true
}

fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) {
Expand Down Expand Up @@ -2858,7 +2865,7 @@ impl<'db> TypeInferenceBuilder<'db> {

// Fall back to non-augmented binary operator inference.
let mut binary_return_ty = || {
self.infer_binary_expression_type(target_type, value_type, op)
self.infer_binary_expression_type(assignment.into(), false, target_type, value_type, op)
.unwrap_or_else(|| {
report_unsupported_augmented_op(&mut self.context);
Type::unknown()
Expand Down Expand Up @@ -4495,19 +4502,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let left_ty = self.infer_expression(left);
let right_ty = self.infer_expression(right);

// Check for division by zero; this doesn't change the inferred type for the expression, but
// may emit a diagnostic
if matches!(
(op, right_ty),
(
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
Type::IntLiteral(0) | Type::BooleanLiteral(false)
)
) {
self.check_division_by_zero(binary, left_ty);
}

self.infer_binary_expression_type(left_ty, right_ty, *op)
self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op)
.unwrap_or_else(|| {
self.context.report_lint(
&UNSUPPORTED_OPERATOR,
Expand All @@ -4524,23 +4519,51 @@ impl<'db> TypeInferenceBuilder<'db> {

fn infer_binary_expression_type(
&mut self,
node: AnyNodeRef<'_>,
mut emitted_division_by_zero_diagnostic: bool,
left_ty: Type<'db>,
right_ty: Type<'db>,
op: ast::Operator,
) -> Option<Type<'db>> {
// Check for division by zero; this doesn't change the inferred type for the expression, but
// may emit a diagnostic
if !emitted_division_by_zero_diagnostic
&& matches!(
(op, right_ty),
(
ast::Operator::Div | ast::Operator::FloorDiv | ast::Operator::Mod,
Type::IntLiteral(0) | Type::BooleanLiteral(false)
)
)
{
emitted_division_by_zero_diagnostic = self.check_division_by_zero(node, op, left_ty);
}

match (left_ty, right_ty, op) {
(Type::Union(lhs_union), rhs, _) => {
let mut union = UnionBuilder::new(self.db());
for lhs in lhs_union.elements(self.db()) {
let result = self.infer_binary_expression_type(*lhs, rhs, op)?;
let result = self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
*lhs,
rhs,
op,
)?;
union = union.add(result);
}
Some(union.build())
}
(lhs, Type::Union(rhs_union), _) => {
let mut union = UnionBuilder::new(self.db());
for rhs in rhs_union.elements(self.db()) {
let result = self.infer_binary_expression_type(lhs, *rhs, op)?;
let result = self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
lhs,
*rhs,
op,
)?;
union = union.add(result);
}
Some(union.build())
Expand Down Expand Up @@ -4659,13 +4682,19 @@ impl<'db> TypeInferenceBuilder<'db> {
}

(Type::BooleanLiteral(bool_value), right, op) => self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
Type::IntLiteral(i64::from(bool_value)),
right,
op,
),
(left, Type::BooleanLiteral(bool_value), op) => {
self.infer_binary_expression_type(left, Type::IntLiteral(i64::from(bool_value)), op)
}
(left, Type::BooleanLiteral(bool_value), op) => self.infer_binary_expression_type(
node,
emitted_division_by_zero_diagnostic,
left,
Type::IntLiteral(i64::from(bool_value)),
op,
),

(Type::Tuple(lhs), Type::Tuple(rhs), ast::Operator::Add) => {
// Note: this only works on heterogeneous tuples.
Expand Down
Loading