diff --git a/newsfragments/5273.fixed.md b/newsfragments/5273.fixed.md new file mode 100644 index 00000000000..bbe7f10ed06 --- /dev/null +++ b/newsfragments/5273.fixed.md @@ -0,0 +1 @@ +Introspection: Fixes introspection of `__richcmp__`, `__concat__`, `__repeat__`, `__inplace_concat__` and `__inplace_repeat__` \ No newline at end of file diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index 10aae4f4810..816a9016ae2 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -448,6 +448,7 @@ impl CallingConvention { } } +#[derive(Clone)] pub struct FnSpec<'a> { pub tp: FnType, // Rust function name diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 5731cb8f510..337a868d412 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -96,13 +96,14 @@ pub struct PyFunctionWarningAttribute { pub span: Span, } -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub enum PyFunctionWarningCategory { Path(Path), UserWarning, DeprecationWarning, // TODO: unused for now, intended for pyo3(deprecated) special-case } +#[derive(Clone)] pub struct PyFunctionWarning { pub message: LitStr, pub category: PyFunctionWarningCategory, diff --git a/pyo3-macros-backend/src/pyfunction/signature.rs b/pyo3-macros-backend/src/pyfunction/signature.rs index bdf8c4a7cb6..306c42f791a 100644 --- a/pyo3-macros-backend/src/pyfunction/signature.rs +++ b/pyo3-macros-backend/src/pyfunction/signature.rs @@ -266,7 +266,7 @@ impl ConstructorAttribute { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct PythonSignature { pub positional_parameters: Vec, pub positional_only_parameters: usize, @@ -286,6 +286,7 @@ impl PythonSignature { } } +#[derive(Clone)] pub struct FunctionSignature<'a> { pub arguments: Vec>, pub python_signature: PythonSignature, diff --git a/pyo3-macros-backend/src/pyimpl.rs b/pyo3-macros-backend/src/pyimpl.rs index 6c111774ced..0259c59b8b8 100644 --- a/pyo3-macros-backend/src/pyimpl.rs +++ b/pyo3-macros-backend/src/pyimpl.rs @@ -16,11 +16,12 @@ use crate::{ }; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -use syn::ImplItemFn; +#[cfg(feature = "experimental-inspect")] +use syn::Ident; use syn::{ parse::{Parse, ParseStream}, spanned::Spanned, - Result, + ImplItemFn, Result, }; /// The mechanism used to collect `#[pymethods]` into the type object @@ -354,22 +355,34 @@ fn method_introspection_code(spec: &FnSpec<'_>, parent: &syn::Type, ctx: &Ctx) - let Ctx { pyo3_path, .. } = ctx; let name = spec.python_name.to_string(); - if matches!( - name.as_str(), - "__richcmp__" - | "__concat__" - | "__repeat__" - | "__inplace_concat__" - | "__inplace_repeat__" - | "__getbuffer__" - | "__releasebuffer__" - | "__traverse__" - | "__clear__" - ) { - // This is not a magic Python method, ignore for now - // TODO: properly implement - return quote! {}; + + // __richcmp__ special case + if name == "__richcmp__" { + // We expend into each individual method + return ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"] + .into_iter() + .map(|method_name| { + let mut spec = (*spec).clone(); + spec.python_name = Ident::new(method_name, spec.python_name.span()); + // We remove the CompareOp arg, this is safe because the signature is always the same + // First the other value to compare with then the CompareOp + // We cant to keep the first argument type, hence this hack + spec.signature.arguments.pop(); + spec.signature.python_signature.positional_parameters.pop(); + method_introspection_code(&spec, parent, ctx) + }) + .collect(); } + // We map or ignore some magic methods + // TODO: this might create a naming conflict + let name = match name.as_str() { + "__concat__" => "__add__".into(), + "__repeat__" => "__mul__".into(), + "__inplace_concat__" => "__iadd__".into(), + "__inplace_repeat__" => "__imul__".into(), + "__getbuffer__" | "__releasebuffer__" | "__traverse__" | "__clear__" => return quote! {}, + _ => name, + }; // We introduce self/cls argument and setup decorators let mut first_argument = None; diff --git a/pytests/src/comparisons.rs b/pytests/src/comparisons.rs index 4ed79e42790..9269ec96621 100644 --- a/pytests/src/comparisons.rs +++ b/pytests/src/comparisons.rs @@ -1,3 +1,4 @@ +use pyo3::basic::CompareOp; use pyo3::prelude::*; #[pyclass] @@ -81,6 +82,21 @@ impl Ordered { } } +#[pyclass] +struct OrderedRichCmp(i64); + +#[pymethods] +impl OrderedRichCmp { + #[new] + fn new(value: i64) -> Self { + Self(value) + } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.0.cmp(&other.0)) + } +} + #[pyclass] struct OrderedDefaultNe(i64); @@ -113,11 +129,7 @@ impl OrderedDefaultNe { } #[pymodule(gil_used = false)] -pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - Ok(()) +pub mod comparisons { + #[pymodule_export] + use super::{Eq, EqDefaultNe, EqDerived, Ordered, OrderedDefaultNe, OrderedRichCmp}; } diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index 9a406d9fe05..e5f00701525 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -23,7 +23,9 @@ mod pyo3_pytests { use super::*; #[pymodule_export] - use {consts::consts, pyclasses::pyclasses, pyfunctions::pyfunctions}; + use { + comparisons::comparisons, consts::consts, pyclasses::pyclasses, pyfunctions::pyfunctions, + }; // Inserting to sys.modules allows importing submodules nicely from Python // e.g. import pyo3_pytests.buf_and_str as bas @@ -32,7 +34,6 @@ mod pyo3_pytests { m.add_wrapped(wrap_pymodule!(awaitable::awaitable))?; #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?; - m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?; #[cfg(not(Py_LIMITED_API))] m.add_wrapped(wrap_pymodule!(datetime::datetime))?; m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?; diff --git a/pytests/stubs/comparisons.pyi b/pytests/stubs/comparisons.pyi new file mode 100644 index 00000000000..7d4f9d6d31e --- /dev/null +++ b/pytests/stubs/comparisons.pyi @@ -0,0 +1,37 @@ +class Eq: + def __eq__(self, /, other: Eq) -> bool: ... + def __ne__(self, /, other: Eq) -> bool: ... + def __new__(cls, /, value: int) -> None: ... + +class EqDefaultNe: + def __eq__(self, /, other: EqDefaultNe) -> bool: ... + def __new__(cls, /, value: int) -> None: ... + +class EqDerived: + def __new__(cls, /, value: int) -> None: ... + +class Ordered: + def __eq__(self, /, other: Ordered) -> bool: ... + def __ge__(self, /, other: Ordered) -> bool: ... + def __gt__(self, /, other: Ordered) -> bool: ... + def __le__(self, /, other: Ordered) -> bool: ... + def __lt__(self, /, other: Ordered) -> bool: ... + def __ne__(self, /, other: Ordered) -> bool: ... + def __new__(cls, /, value: int) -> None: ... + +class OrderedDefaultNe: + def __eq__(self, /, other: OrderedDefaultNe) -> bool: ... + def __ge__(self, /, other: OrderedDefaultNe) -> bool: ... + def __gt__(self, /, other: OrderedDefaultNe) -> bool: ... + def __le__(self, /, other: OrderedDefaultNe) -> bool: ... + def __lt__(self, /, other: OrderedDefaultNe) -> bool: ... + def __new__(cls, /, value: int) -> None: ... + +class OrderedRichCmp: + def __eq__(self, /, other: OrderedRichCmp) -> bool: ... + def __ge__(self, /, other: OrderedRichCmp) -> bool: ... + def __gt__(self, /, other: OrderedRichCmp) -> bool: ... + def __le__(self, /, other: OrderedRichCmp) -> bool: ... + def __lt__(self, /, other: OrderedRichCmp) -> bool: ... + def __ne__(self, /, other: OrderedRichCmp) -> bool: ... + def __new__(cls, /, value: int) -> None: ... diff --git a/pytests/tests/test_comparisons.py b/pytests/tests/test_comparisons.py index fe4d8f31f62..7b836a60fda 100644 --- a/pytests/tests/test_comparisons.py +++ b/pytests/tests/test_comparisons.py @@ -8,6 +8,7 @@ EqDerived, Ordered, OrderedDefaultNe, + OrderedRichCmp, ) from typing_extensions import Self @@ -132,8 +133,10 @@ def __ge__(self, other: Self) -> bool: return self.x >= other.x -@pytest.mark.parametrize("ty", (Ordered, PyOrdered), ids=("rust", "python")) -def test_ordered(ty: Type[Union[Ordered, PyOrdered]]): +@pytest.mark.parametrize( + "ty", (Ordered, OrderedRichCmp, PyOrdered), ids=("rust", "rust-richcmp", "python") +) +def test_ordered(ty: Type[Union[Ordered, OrderedRichCmp, PyOrdered]]): a = ty(0) b = ty(0) c = ty(1) diff --git a/src/types/boolobject.rs b/src/types/boolobject.rs index a3538c60cc1..c6fca362d64 100644 --- a/src/types/boolobject.rs +++ b/src/types/boolobject.rs @@ -141,6 +141,9 @@ impl<'py> IntoPyObject<'py> for bool { type Output = Borrowed<'py, 'py, Self::Target>; type Error = Infallible; + #[cfg(feature = "experimental-inspect")] + const OUTPUT_TYPE: &'static str = "bool"; + #[inline] fn into_pyobject(self, py: Python<'py>) -> Result { Ok(PyBool::new(py, self)) @@ -157,6 +160,9 @@ impl<'py> IntoPyObject<'py> for &bool { type Output = Borrowed<'py, 'py, Self::Target>; type Error = Infallible; + #[cfg(feature = "experimental-inspect")] + const OUTPUT_TYPE: &'static str = bool::OUTPUT_TYPE; + #[inline] fn into_pyobject(self, py: Python<'py>) -> Result { (*self).into_pyobject(py)