From c1a33de3d9b2a7692bce85dfa1f5f2aa4f1bfa69 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Sun, 13 Aug 2023 16:39:02 -0400 Subject: [PATCH] Add support for unions to our Python builtins type system --- Cargo.lock | 1 + .../fixtures/pylint/bad_string_format_type.py | 2 + .../pylint/rules/bad_string_format_type.rs | 34 +- .../pylint/rules/invalid_envvar_value.rs | 6 +- .../rules/pylint/rules/invalid_str_return.rs | 6 +- ...ts__PLE1307_bad_string_format_type.py.snap | 24 +- ...ests__PLE1507_invalid_envvar_value.py.snap | 20 + crates/ruff_python_semantic/Cargo.toml | 5 +- .../src/analyze/type_inference.rs | 532 +++++++++++++++--- 9 files changed, 529 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b7341c4264c3a..ed93d2eb4df85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2423,6 +2423,7 @@ dependencies = [ "num-traits", "ruff_index", "ruff_python_ast", + "ruff_python_parser", "ruff_python_stdlib", "ruff_source_file", "ruff_text_size", diff --git a/crates/ruff/resources/test/fixtures/pylint/bad_string_format_type.py b/crates/ruff/resources/test/fixtures/pylint/bad_string_format_type.py index 66d6f2a5fd203..3fe6722401507 100644 --- a/crates/ruff/resources/test/fixtures/pylint/bad_string_format_type.py +++ b/crates/ruff/resources/test/fixtures/pylint/bad_string_format_type.py @@ -13,6 +13,7 @@ "%(key)d" % {"key": []} print("%d" % ("%s" % ("nested",),)) "%d" % ((1, 2, 3),) +"%d" % (1 if x > 0 else []) # False negatives WORD = "abc" @@ -55,3 +56,4 @@ "%d" % (len(foo),) '(%r, %r, %r, %r)' % (hostname, address, username, '$PASSWORD') '%r' % ({'server_school_roles': server_school_roles, 'is_school_multiserver_domain': is_school_multiserver_domain}, ) +"%d" % (1 if x > 0 else 2) diff --git a/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs b/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs index 1f2905751d0fe..dfb7283f60f9d 100644 --- a/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs +++ b/crates/ruff/src/rules/pylint/rules/bad_string_format_type.rs @@ -9,7 +9,7 @@ use rustc_hash::FxHashMap; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::str::{leading_quote, trailing_quote}; -use ruff_python_semantic::analyze::type_inference::PythonType; +use ruff_python_semantic::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType}; use crate::checkers::ast::Checker; @@ -59,14 +59,16 @@ impl FormatType { | PythonType::Set | PythonType::Tuple | PythonType::Generator - | PythonType::Complex - | PythonType::Bool | PythonType::Ellipsis | PythonType::None => matches!( self, FormatType::Unknown | FormatType::String | FormatType::Repr ), - PythonType::Integer => matches!( + PythonType::Number(NumberLike::Complex | NumberLike::Bool) => matches!( + self, + FormatType::Unknown | FormatType::String | FormatType::Repr + ), + PythonType::Number(NumberLike::Integer) => matches!( self, FormatType::Unknown | FormatType::String @@ -75,7 +77,7 @@ impl FormatType { | FormatType::Float | FormatType::Number ), - PythonType::Float => matches!( + PythonType::Number(NumberLike::Float) => matches!( self, FormatType::Unknown | FormatType::String @@ -83,7 +85,6 @@ impl FormatType { | FormatType::Float | FormatType::Number ), - PythonType::Unknown => true, } } } @@ -118,16 +119,22 @@ fn collect_specs(formats: &[CFormatStrOrBytes]) -> Vec<&CFormatSpec> { /// Return `true` if the format string is equivalent to the constant type fn equivalent(format: &CFormatSpec, value: &Expr) -> bool { - let format: FormatType = format.format_char.into(); - let constant: PythonType = value.into(); - format.is_compatible_with(constant) + let format = FormatType::from(format.format_char); + match ResolvedPythonType::from(value) { + ResolvedPythonType::Atom(atom) => format.is_compatible_with(atom), + ResolvedPythonType::Union(atoms) => { + atoms.iter().all(|atom| format.is_compatible_with(*atom)) + } + ResolvedPythonType::Unknown => true, + ResolvedPythonType::TypeError => true, + } } -/// Return `true` if the [`Constnat`] aligns with the format type. +/// Return `true` if the [`Constant`] aligns with the format type. fn is_valid_constant(formats: &[CFormatStrOrBytes], value: &Expr) -> bool { let formats = collect_specs(formats); - // If there is more than one format, this is not valid python and we should - // return true so that no error is reported + // If there is more than one format, this is not valid Python and we should + // return true so that no error is reported. let [format] = formats.as_slice() else { return true; }; @@ -242,8 +249,7 @@ pub(crate) fn bad_string_format_type(checker: &mut Checker, expr: &Expr, right: values, range: _, }) => is_valid_dict(&format_strings, keys, values), - Expr::Constant(_) => is_valid_constant(&format_strings, right), - _ => true, + _ => is_valid_constant(&format_strings, right), }; if !is_valid { checker diff --git a/crates/ruff/src/rules/pylint/rules/invalid_envvar_value.rs b/crates/ruff/src/rules/pylint/rules/invalid_envvar_value.rs index 89c6e509bee28..174a6a1359743 100644 --- a/crates/ruff/src/rules/pylint/rules/invalid_envvar_value.rs +++ b/crates/ruff/src/rules/pylint/rules/invalid_envvar_value.rs @@ -1,7 +1,7 @@ use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::{self as ast, Ranged}; -use ruff_python_semantic::analyze::type_inference::PythonType; +use ruff_python_semantic::analyze::type_inference::{PythonType, ResolvedPythonType}; use crate::checkers::ast::Checker; @@ -46,8 +46,8 @@ pub(crate) fn invalid_envvar_value(checker: &mut Checker, call: &ast::ExprCall) }; if matches!( - PythonType::from(expr), - PythonType::String | PythonType::Unknown + ResolvedPythonType::from(expr), + ResolvedPythonType::Unknown | ResolvedPythonType::Atom(PythonType::String) ) { return; } diff --git a/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs b/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs index 57d00bca939aa..3dd341a8f5f28 100644 --- a/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs +++ b/crates/ruff/src/rules/pylint/rules/invalid_str_return.rs @@ -3,7 +3,7 @@ use ruff_python_ast::{Ranged, Stmt}; use ruff_diagnostics::{Diagnostic, Violation}; use ruff_macros::{derive_message_formats, violation}; use ruff_python_ast::{helpers::ReturnStatementVisitor, statement_visitor::StatementVisitor}; -use ruff_python_semantic::analyze::type_inference::PythonType; +use ruff_python_semantic::analyze::type_inference::{PythonType, ResolvedPythonType}; use crate::checkers::ast::Checker; @@ -42,8 +42,8 @@ pub(crate) fn invalid_str_return(checker: &mut Checker, name: &str, body: &[Stmt for stmt in returns { if let Some(value) = stmt.value.as_deref() { if !matches!( - PythonType::from(value), - PythonType::String | PythonType::Unknown + ResolvedPythonType::from(value), + ResolvedPythonType::Unknown | ResolvedPythonType::Atom(PythonType::String) ) { checker .diagnostics diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1307_bad_string_format_type.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1307_bad_string_format_type.py.snap index a0366a476ba51..e50006d76397e 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1307_bad_string_format_type.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1307_bad_string_format_type.py.snap @@ -69,6 +69,16 @@ bad_string_format_type.py:10:1: PLE1307 Format type does not match argument type 12 | "%d" % ([],) | +bad_string_format_type.py:11:1: PLE1307 Format type does not match argument type + | + 9 | "%x" % 1.1 +10 | "%(key)x" % {"key": 1.1} +11 | "%d" % [] + | ^^^^^^^^^ PLE1307 +12 | "%d" % ([],) +13 | "%(key)d" % {"key": []} + | + bad_string_format_type.py:12:1: PLE1307 Format type does not match argument type | 10 | "%(key)x" % {"key": 1.1} @@ -96,6 +106,7 @@ bad_string_format_type.py:14:7: PLE1307 Format type does not match argument type 14 | print("%d" % ("%s" % ("nested",),)) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1307 15 | "%d" % ((1, 2, 3),) +16 | "%d" % (1 if x > 0 else []) | bad_string_format_type.py:15:1: PLE1307 Format type does not match argument type @@ -104,8 +115,17 @@ bad_string_format_type.py:15:1: PLE1307 Format type does not match argument type 14 | print("%d" % ("%s" % ("nested",),)) 15 | "%d" % ((1, 2, 3),) | ^^^^^^^^^^^^^^^^^^^ PLE1307 -16 | -17 | # False negatives +16 | "%d" % (1 if x > 0 else []) + | + +bad_string_format_type.py:16:1: PLE1307 Format type does not match argument type + | +14 | print("%d" % ("%s" % ("nested",),)) +15 | "%d" % ((1, 2, 3),) +16 | "%d" % (1 if x > 0 else []) + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1307 +17 | +18 | # False negatives | diff --git a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1507_invalid_envvar_value.py.snap b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1507_invalid_envvar_value.py.snap index d40d4ce4dc434..6564ff6dd041a 100644 --- a/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1507_invalid_envvar_value.py.snap +++ b/crates/ruff/src/rules/pylint/snapshots/ruff__rules__pylint__tests__PLE1507_invalid_envvar_value.py.snap @@ -31,4 +31,24 @@ invalid_envvar_value.py:8:11: PLE1507 Invalid type for initial `os.getenv` argum 10 | os.getenv(key=f"foo", default="bar") | +invalid_envvar_value.py:12:15: PLE1507 Invalid type for initial `os.getenv` argument; expected `str` + | +10 | os.getenv(key=f"foo", default="bar") +11 | os.getenv(key="foo" + "bar", default=1) +12 | os.getenv(key=1 + "bar", default=1) # [invalid-envvar-value] + | ^^^^^^^^^ PLE1507 +13 | os.getenv("PATH_TEST" if using_clear_path else "PATH_ORIG") +14 | os.getenv(1 if using_clear_path else "PATH_ORIG") + | + +invalid_envvar_value.py:14:11: PLE1507 Invalid type for initial `os.getenv` argument; expected `str` + | +12 | os.getenv(key=1 + "bar", default=1) # [invalid-envvar-value] +13 | os.getenv("PATH_TEST" if using_clear_path else "PATH_ORIG") +14 | os.getenv(1 if using_clear_path else "PATH_ORIG") + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ PLE1507 +15 | +16 | AA = "aa" + | + diff --git a/crates/ruff_python_semantic/Cargo.toml b/crates/ruff_python_semantic/Cargo.toml index 9b37767711b83..58484d140d24b 100644 --- a/crates/ruff_python_semantic/Cargo.toml +++ b/crates/ruff_python_semantic/Cargo.toml @@ -23,6 +23,7 @@ bitflags = { workspace = true } is-macro = { workspace = true } num-traits = { workspace = true } rustc-hash = { workspace = true } - - smallvec = { workspace = true } + +[dev-dependencies] +ruff_python_parser = { path = "../ruff_python_parser" } diff --git a/crates/ruff_python_semantic/src/analyze/type_inference.rs b/crates/ruff_python_semantic/src/analyze/type_inference.rs index 52445620dbd7e..4bba5b9826010 100644 --- a/crates/ruff_python_semantic/src/analyze/type_inference.rs +++ b/crates/ruff_python_semantic/src/analyze/type_inference.rs @@ -1,7 +1,317 @@ //! Analysis rules to perform basic type inference on individual expressions. +use rustc_hash::FxHashSet; + use ruff_python_ast as ast; -use ruff_python_ast::{Constant, Expr, Operator}; +use ruff_python_ast::{Constant, Expr, Operator, UnaryOp}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ResolvedPythonType { + /// The expression resolved to a single known type, like `str` or `int`. + Atom(PythonType), + /// The expression resolved to a union of known types, like `str | int`. + Union(FxHashSet), + /// The expression resolved to an unknown type, like a variable or function call. + Unknown, + /// The expression resolved to a `TypeError`, like `1 + "hello"`. + TypeError, +} + +impl ResolvedPythonType { + #[must_use] + pub fn union(self, other: Self) -> Self { + match (self, other) { + (Self::TypeError, _) | (_, Self::TypeError) => Self::TypeError, + (Self::Unknown, _) | (_, Self::Unknown) => Self::Unknown, + (Self::Atom(a), Self::Atom(b)) => { + if a == b { + Self::Atom(a) + } else { + Self::Union(FxHashSet::from_iter([a, b])) + } + } + (Self::Atom(a), Self::Union(mut b)) => { + b.insert(a); + Self::Union(b) + } + (Self::Union(mut a), Self::Atom(b)) => { + a.insert(b); + Self::Union(a) + } + (Self::Union(mut a), Self::Union(b)) => { + a.extend(b); + Self::Union(a) + } + } + } +} + +impl From<&Expr> for ResolvedPythonType { + fn from(expr: &Expr) -> Self { + match expr { + // Primitives. + Expr::Dict(_) => ResolvedPythonType::Atom(PythonType::Dict), + Expr::DictComp(_) => ResolvedPythonType::Atom(PythonType::Dict), + Expr::Set(_) => ResolvedPythonType::Atom(PythonType::Set), + Expr::SetComp(_) => ResolvedPythonType::Atom(PythonType::Set), + Expr::List(_) => ResolvedPythonType::Atom(PythonType::List), + Expr::ListComp(_) => ResolvedPythonType::Atom(PythonType::List), + Expr::Tuple(_) => ResolvedPythonType::Atom(PythonType::Tuple), + Expr::GeneratorExp(_) => ResolvedPythonType::Atom(PythonType::Generator), + Expr::FString(_) => ResolvedPythonType::Atom(PythonType::String), + Expr::Constant(ast::ExprConstant { value, .. }) => match value { + Constant::Str(_) => ResolvedPythonType::Atom(PythonType::String), + Constant::Int(_) => { + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + } + Constant::Float(_) => { + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + } + Constant::Bool(_) => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)), + Constant::Complex { .. } => { + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex)) + } + Constant::None => ResolvedPythonType::Atom(PythonType::None), + Constant::Ellipsis => ResolvedPythonType::Atom(PythonType::Ellipsis), + Constant::Bytes(_) => ResolvedPythonType::Atom(PythonType::Bytes), + }, + // Simple container expressions. + Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => { + ResolvedPythonType::from(value.as_ref()) + } + Expr::IfExp(ast::ExprIfExp { body, orelse, .. }) => { + let body = ResolvedPythonType::from(body.as_ref()); + let orelse = ResolvedPythonType::from(orelse.as_ref()); + body.union(orelse) + } + + // Boolean operators. + Expr::BoolOp(ast::ExprBoolOp { values, .. }) => values + .iter() + .map(ResolvedPythonType::from) + .reduce(ResolvedPythonType::union) + .unwrap_or(ResolvedPythonType::Unknown), + + // Unary operators. + Expr::UnaryOp(ast::ExprUnaryOp { operand, op, .. }) => match op { + UnaryOp::Invert => { + return match ResolvedPythonType::from(operand.as_ref()) { + ResolvedPythonType::Atom(PythonType::Number( + NumberLike::Bool | NumberLike::Integer, + )) => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)), + ResolvedPythonType::Atom(_) => ResolvedPythonType::TypeError, + _ => ResolvedPythonType::Unknown, + } + } + // Ex) `not 1.0` + UnaryOp::Not => ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)), + // Ex) `+1` or `-1` + UnaryOp::UAdd | UnaryOp::USub => { + return match ResolvedPythonType::from(operand.as_ref()) { + ResolvedPythonType::Atom(PythonType::Number(number)) => { + ResolvedPythonType::Atom(PythonType::Number( + if number == NumberLike::Bool { + NumberLike::Integer + } else { + number + }, + )) + } + ResolvedPythonType::Atom(_) => ResolvedPythonType::TypeError, + _ => ResolvedPythonType::Unknown, + } + } + }, + + // Binary operators. + Expr::BinOp(ast::ExprBinOp { + left, op, right, .. + }) => { + match op { + Operator::Add => { + match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `"Hello" + "world"` + ( + ResolvedPythonType::Atom(PythonType::String), + ResolvedPythonType::Atom(PythonType::String), + ) => return ResolvedPythonType::Atom(PythonType::String), + // Ex) `b"Hello" + b"world"` + ( + ResolvedPythonType::Atom(PythonType::Bytes), + ResolvedPythonType::Atom(PythonType::Bytes), + ) => return ResolvedPythonType::Atom(PythonType::Bytes), + // Ex) `[1] + [2]` + ( + ResolvedPythonType::Atom(PythonType::List), + ResolvedPythonType::Atom(PythonType::List), + ) => return ResolvedPythonType::Atom(PythonType::List), + // Ex) `(1, 2) + (3, 4)` + ( + ResolvedPythonType::Atom(PythonType::Tuple), + ResolvedPythonType::Atom(PythonType::Tuple), + ) => return ResolvedPythonType::Atom(PythonType::Tuple), + // Ex) `1 + 1.0` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + return ResolvedPythonType::Atom(PythonType::Number( + left.coerce(right), + )); + } + // Ex) `"a" + 1` + (ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => { + return ResolvedPythonType::TypeError; + } + _ => {} + } + } + Operator::Sub => { + match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `1 - 1` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + return ResolvedPythonType::Atom(PythonType::Number( + left.coerce(right), + )); + } + // Ex) `{1, 2} - {2}` + ( + ResolvedPythonType::Atom(PythonType::Set), + ResolvedPythonType::Atom(PythonType::Set), + ) => return ResolvedPythonType::Atom(PythonType::Set), + // Ex) `"a" - "b"` + (ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => { + return ResolvedPythonType::TypeError; + } + _ => {} + } + } + // Ex) "a" % "b" + Operator::Mod => match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `"Hello" % "world"` + (ResolvedPythonType::Atom(PythonType::String), _) => { + return ResolvedPythonType::Atom(PythonType::String) + } + // Ex) `b"Hello" % b"world"` + (ResolvedPythonType::Atom(PythonType::Bytes), _) => { + return ResolvedPythonType::Atom(PythonType::Bytes) + } + // Ex) `1 % 2` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + return ResolvedPythonType::Atom(PythonType::Number( + left.coerce(right), + )); + } + _ => {} + }, + // Standard arithmetic operators, which coerce to the "highest" number type. + Operator::Mult | Operator::FloorDiv | Operator::Pow => match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `1 - 2` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + return ResolvedPythonType::Atom(PythonType::Number( + left.coerce(right), + )); + } + (ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => { + return ResolvedPythonType::TypeError; + } + _ => {} + }, + // Division, which returns at least `float`. + Operator::Div => match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `1 / 2` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + let resolved = left.coerce(right); + return ResolvedPythonType::Atom(PythonType::Number( + if resolved == NumberLike::Integer { + NumberLike::Float + } else { + resolved + }, + )); + } + (ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => { + return ResolvedPythonType::TypeError; + } + _ => {} + }, + // Bitwise operators, which only work on `int` and `bool`. + Operator::BitAnd + | Operator::BitOr + | Operator::BitXor + | Operator::LShift + | Operator::RShift => { + match ( + ResolvedPythonType::from(left.as_ref()), + ResolvedPythonType::from(right.as_ref()), + ) { + // Ex) `1 & 2` + ( + ResolvedPythonType::Atom(PythonType::Number(left)), + ResolvedPythonType::Atom(PythonType::Number(right)), + ) => { + let resolved = left.coerce(right); + return if resolved == NumberLike::Integer { + ResolvedPythonType::Atom(PythonType::Number( + NumberLike::Integer, + )) + } else { + ResolvedPythonType::TypeError + }; + } + (ResolvedPythonType::Atom(_), ResolvedPythonType::Atom(_)) => { + return ResolvedPythonType::TypeError; + } + _ => {} + } + } + Operator::MatMult => {} + } + ResolvedPythonType::Unknown + } + Expr::Lambda(_) + | Expr::Await(_) + | Expr::Yield(_) + | Expr::YieldFrom(_) + | Expr::Compare(_) + | Expr::Call(_) + | Expr::FormattedValue(_) + | Expr::Attribute(_) + | Expr::Subscript(_) + | Expr::Starred(_) + | Expr::Name(_) + | Expr::Slice(_) + | Expr::IpyEscapeCommand(_) => ResolvedPythonType::Unknown, + } + } +} /// An extremely simple type inference system for individual expressions. /// @@ -9,20 +319,14 @@ use ruff_python_ast::{Constant, Expr, Operator}; /// such as strings, integers, floats, and containers. It cannot infer the /// types of variables or expressions that are not statically known from /// individual AST nodes alone. -#[derive(Debug, Copy, Clone, PartialEq, Eq, is_macro::Is)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum PythonType { /// A string literal, such as `"hello"`. String, /// A bytes literal, such as `b"hello"`. Bytes, - /// An integer literal, such as `1` or `0x1`. - Integer, - /// A floating-point literal, such as `1.0` or `1e10`. - Float, - /// A complex literal, such as `1j` or `1+1j`. - Complex, - /// A boolean literal, such as `True` or `False`. - Bool, + /// An integer, float, or complex literal, such as `1` or `1.0`. + Number(NumberLike), /// A `None` literal, such as `None`. None, /// An ellipsis literal, such as `...`. @@ -37,75 +341,149 @@ pub enum PythonType { Tuple, /// A generator expression, such as `(x for x in range(10))`. Generator, - /// An unknown type, such as a variable or function call. - Unknown, } -impl From<&Expr> for PythonType { - fn from(expr: &Expr) -> Self { - match expr { - Expr::NamedExpr(ast::ExprNamedExpr { value, .. }) => (value.as_ref()).into(), - Expr::UnaryOp(ast::ExprUnaryOp { operand, .. }) => (operand.as_ref()).into(), - Expr::Dict(_) => PythonType::Dict, - Expr::DictComp(_) => PythonType::Dict, - Expr::Set(_) => PythonType::Set, - Expr::SetComp(_) => PythonType::Set, - Expr::List(_) => PythonType::List, - Expr::ListComp(_) => PythonType::List, - Expr::Tuple(_) => PythonType::Tuple, - Expr::GeneratorExp(_) => PythonType::Generator, - Expr::FString(_) => PythonType::String, - Expr::IfExp(ast::ExprIfExp { body, orelse, .. }) => { - let body = PythonType::from(body.as_ref()); - let orelse = PythonType::from(orelse.as_ref()); - // TODO(charlie): If we have two known types, we should return a union. As-is, - // callers that ignore the `Unknown` type will allow invalid expressions (e.g., - // if you're testing for strings, you may accept `String` or `Unknown`, and you'd - // now accept, e.g., `1 if True else "a"`, which resolves to `Unknown`). - if body == orelse { - body - } else { - PythonType::Unknown - } - } - Expr::BinOp(ast::ExprBinOp { - left, op, right, .. - }) => { - match op { - // Ex) "a" + "b" - Operator::Add => { - match ( - PythonType::from(left.as_ref()), - PythonType::from(right.as_ref()), - ) { - (PythonType::String, PythonType::String) => return PythonType::String, - (PythonType::Bytes, PythonType::Bytes) => return PythonType::Bytes, - // TODO(charlie): If we have two known types, they may be incompatible. - // Return an error (e.g., for `1 + "a"`). - _ => {} - } - } - // Ex) "a" % "b" - Operator::Mod => match PythonType::from(left.as_ref()) { - PythonType::String => return PythonType::String, - PythonType::Bytes => return PythonType::Bytes, - _ => {} - }, - _ => {} - } - PythonType::Unknown - } - Expr::Constant(ast::ExprConstant { value, .. }) => match value { - Constant::Str(_) => PythonType::String, - Constant::Int(_) => PythonType::Integer, - Constant::Float(_) => PythonType::Float, - Constant::Bool(_) => PythonType::Bool, - Constant::Complex { .. } => PythonType::Complex, - Constant::None => PythonType::None, - Constant::Ellipsis => PythonType::Ellipsis, - Constant::Bytes(_) => PythonType::Bytes, - }, - _ => PythonType::Unknown, +/// A numeric type, or a type that can be trivially coerced to a numeric type. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum NumberLike { + /// An integer literal, such as `1` or `0x1`. + Integer, + /// A floating-point literal, such as `1.0` or `1e10`. + Float, + /// A complex literal, such as `1j` or `1+1j`. + Complex, + /// A boolean literal, such as `True` or `False`. + Bool, +} + +impl NumberLike { + /// Coerces two number-like types to the "highest" number-like type. + #[must_use] + pub fn coerce(self, other: NumberLike) -> NumberLike { + match (self, other) { + (NumberLike::Complex, _) | (_, NumberLike::Complex) => NumberLike::Complex, + (NumberLike::Float, _) | (_, NumberLike::Float) => NumberLike::Float, + _ => NumberLike::Integer, } } } + +#[cfg(test)] +mod tests { + use rustc_hash::FxHashSet; + + use ruff_python_ast::Expr; + use ruff_python_parser::parse_expression; + + use crate::analyze::type_inference::{NumberLike, PythonType, ResolvedPythonType}; + + fn parse(expression: &str) -> Expr { + parse_expression(expression, "").unwrap() + } + + #[test] + fn type_inference() { + // Atoms. + assert_eq!( + ResolvedPythonType::from(&parse("1")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("'Hello, world'")), + ResolvedPythonType::Atom(PythonType::String) + ); + assert_eq!( + ResolvedPythonType::from(&parse("b'Hello, world'")), + ResolvedPythonType::Atom(PythonType::Bytes) + ); + assert_eq!( + ResolvedPythonType::from(&parse("'Hello' % 'world'")), + ResolvedPythonType::Atom(PythonType::String) + ); + + // Boolean operators. + assert_eq!( + ResolvedPythonType::from(&parse("1 and 2")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1 and True")), + ResolvedPythonType::Union(FxHashSet::from_iter([ + PythonType::Number(NumberLike::Integer), + PythonType::Number(NumberLike::Bool) + ])) + ); + + // Binary operators. + assert_eq!( + ResolvedPythonType::from(&parse("1.0 * 2")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("2 * 1.0")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1.0 * 2j")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1 / True")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1 / 2")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("{1, 2} - {2}")), + ResolvedPythonType::Atom(PythonType::Set) + ); + + // Unary operators. + assert_eq!( + ResolvedPythonType::from(&parse("-1")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("-1.0")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Float)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("-1j")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Complex)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("-True")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("not 'Hello'")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("not x.y.z")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Bool)) + ); + + // Conditional expressions. + assert_eq!( + ResolvedPythonType::from(&parse("1 if True else 2")), + ResolvedPythonType::Atom(PythonType::Number(NumberLike::Integer)) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1 if True else 2.0")), + ResolvedPythonType::Union(FxHashSet::from_iter([ + PythonType::Number(NumberLike::Integer), + PythonType::Number(NumberLike::Float) + ])) + ); + assert_eq!( + ResolvedPythonType::from(&parse("1 if True else False")), + ResolvedPythonType::Union(FxHashSet::from_iter([ + PythonType::Number(NumberLike::Integer), + PythonType::Number(NumberLike::Bool) + ])) + ); + } +}