Skip to content

Commit e42bca5

Browse files
committed
Refine some magic methods introspection
- __richcmp__ is mapped to le/lt/ge/gt/eq/ne - __concat__, __repeat__, __inplace_concat__ and __inplace_repeat__ are mapped to their Python equivalent
1 parent de0bbee commit e42bca5

File tree

10 files changed

+104
-29
lines changed

10 files changed

+104
-29
lines changed

newsfragments/5273.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Introspection: Fixes introspection of `__richcmp__`, `__concat__`, `__repeat__`, `__inplace_concat__` and `__inplace_repeat__`

pyo3-macros-backend/src/method.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ impl CallingConvention {
448448
}
449449
}
450450

451+
#[derive(Clone)]
451452
pub struct FnSpec<'a> {
452453
pub tp: FnType,
453454
// Rust function name

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,14 @@ pub struct PyFunctionWarningAttribute {
9696
pub span: Span,
9797
}
9898

99-
#[derive(PartialEq)]
99+
#[derive(PartialEq, Clone)]
100100
pub enum PyFunctionWarningCategory {
101101
Path(Path),
102102
UserWarning,
103103
DeprecationWarning, // TODO: unused for now, intended for pyo3(deprecated) special-case
104104
}
105105

106+
#[derive(Clone)]
106107
pub struct PyFunctionWarning {
107108
pub message: LitStr,
108109
pub category: PyFunctionWarningCategory,

pyo3-macros-backend/src/pyfunction/signature.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ impl ConstructorAttribute {
266266
}
267267
}
268268

269-
#[derive(Default)]
269+
#[derive(Default, Clone)]
270270
pub struct PythonSignature {
271271
pub positional_parameters: Vec<String>,
272272
pub positional_only_parameters: usize,
@@ -286,6 +286,7 @@ impl PythonSignature {
286286
}
287287
}
288288

289+
#[derive(Clone)]
289290
pub struct FunctionSignature<'a> {
290291
pub arguments: Vec<FnArg<'a>>,
291292
pub python_signature: PythonSignature,

pyo3-macros-backend/src/pyimpl.rs

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use crate::{
1414
self, is_proto_method, GeneratedPyMethod, MethodAndMethodDef, MethodAndSlotDef, PyMethod,
1515
},
1616
};
17-
use proc_macro2::TokenStream;
17+
use proc_macro2::{Ident, TokenStream};
1818
use quote::{format_ident, quote};
1919
use syn::ImplItemFn;
2020
use syn::{
@@ -354,22 +354,34 @@ fn method_introspection_code(spec: &FnSpec<'_>, parent: &syn::Type, ctx: &Ctx) -
354354
let Ctx { pyo3_path, .. } = ctx;
355355

356356
let name = spec.python_name.to_string();
357-
if matches!(
358-
name.as_str(),
359-
"__richcmp__"
360-
| "__concat__"
361-
| "__repeat__"
362-
| "__inplace_concat__"
363-
| "__inplace_repeat__"
364-
| "__getbuffer__"
365-
| "__releasebuffer__"
366-
| "__traverse__"
367-
| "__clear__"
368-
) {
369-
// This is not a magic Python method, ignore for now
370-
// TODO: properly implement
371-
return quote! {};
357+
358+
// __richcmp__ special case
359+
if name == "__richcmp__" {
360+
// We expend into each individual method
361+
return ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]
362+
.into_iter()
363+
.map(|method_name| {
364+
let mut spec = (*spec).clone();
365+
spec.python_name = Ident::new(method_name, spec.python_name.span());
366+
// We remove the CompareOp arg, this is safe because the signature is always the same
367+
// First the other value to compare with then the CompareOp
368+
// We cant to keep the first argument type, hence this hack
369+
spec.signature.arguments.pop();
370+
spec.signature.python_signature.positional_parameters.pop();
371+
method_introspection_code(&spec, parent, ctx)
372+
})
373+
.collect();
372374
}
375+
// We map or ignore some magic methods
376+
// TODO: this might create a naming conflict
377+
let name = match name.as_str() {
378+
"__concat__" => "__add__".into(),
379+
"__repeat__" => "__mul__".into(),
380+
"__inplace_concat__" => "__iadd__".into(),
381+
"__inplace_repeat__" => "__imul__".into(),
382+
"__getbuffer__" | "__releasebuffer__" | "__traverse__" | "__clear__" => return quote! {},
383+
_ => name,
384+
};
373385

374386
// We introduce self/cls argument and setup decorators
375387
let mut first_argument = None;

pytests/src/comparisons.rs

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use pyo3::basic::CompareOp;
12
use pyo3::prelude::*;
23

34
#[pyclass]
@@ -81,6 +82,21 @@ impl Ordered {
8182
}
8283
}
8384

85+
#[pyclass]
86+
struct OrderedRichCmp(i64);
87+
88+
#[pymethods]
89+
impl OrderedRichCmp {
90+
#[new]
91+
fn new(value: i64) -> Self {
92+
Self(value)
93+
}
94+
95+
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
96+
op.matches(self.0.cmp(&other.0))
97+
}
98+
}
99+
84100
#[pyclass]
85101
struct OrderedDefaultNe(i64);
86102

@@ -113,11 +129,7 @@ impl OrderedDefaultNe {
113129
}
114130

115131
#[pymodule(gil_used = false)]
116-
pub fn comparisons(m: &Bound<'_, PyModule>) -> PyResult<()> {
117-
m.add_class::<Eq>()?;
118-
m.add_class::<EqDefaultNe>()?;
119-
m.add_class::<EqDerived>()?;
120-
m.add_class::<Ordered>()?;
121-
m.add_class::<OrderedDefaultNe>()?;
122-
Ok(())
132+
pub mod comparisons {
133+
#[pymodule_export]
134+
use super::{Eq, EqDefaultNe, EqDerived, Ordered, OrderedDefaultNe, OrderedRichCmp};
123135
}

pytests/src/lib.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ mod pyo3_pytests {
2323
use super::*;
2424

2525
#[pymodule_export]
26-
use {consts::consts, pyclasses::pyclasses, pyfunctions::pyfunctions};
26+
use {
27+
comparisons::comparisons, consts::consts, pyclasses::pyclasses, pyfunctions::pyfunctions,
28+
};
2729

2830
// Inserting to sys.modules allows importing submodules nicely from Python
2931
// e.g. import pyo3_pytests.buf_and_str as bas
@@ -32,7 +34,6 @@ mod pyo3_pytests {
3234
m.add_wrapped(wrap_pymodule!(awaitable::awaitable))?;
3335
#[cfg(not(Py_LIMITED_API))]
3436
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
35-
m.add_wrapped(wrap_pymodule!(comparisons::comparisons))?;
3637
#[cfg(not(Py_LIMITED_API))]
3738
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
3839
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;

pytests/stubs/comparisons.pyi

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
class Eq:
2+
def __eq__(self, /, other: Eq) -> bool: ...
3+
def __ne__(self, /, other: Eq) -> bool: ...
4+
def __new__(cls, /, value: int) -> None: ...
5+
6+
class EqDefaultNe:
7+
def __eq__(self, /, other: EqDefaultNe) -> bool: ...
8+
def __new__(cls, /, value: int) -> None: ...
9+
10+
class EqDerived:
11+
def __new__(cls, /, value: int) -> None: ...
12+
13+
class Ordered:
14+
def __eq__(self, /, other: Ordered) -> bool: ...
15+
def __ge__(self, /, other: Ordered) -> bool: ...
16+
def __gt__(self, /, other: Ordered) -> bool: ...
17+
def __le__(self, /, other: Ordered) -> bool: ...
18+
def __lt__(self, /, other: Ordered) -> bool: ...
19+
def __ne__(self, /, other: Ordered) -> bool: ...
20+
def __new__(cls, /, value: int) -> None: ...
21+
22+
class OrderedDefaultNe:
23+
def __eq__(self, /, other: OrderedDefaultNe) -> bool: ...
24+
def __ge__(self, /, other: OrderedDefaultNe) -> bool: ...
25+
def __gt__(self, /, other: OrderedDefaultNe) -> bool: ...
26+
def __le__(self, /, other: OrderedDefaultNe) -> bool: ...
27+
def __lt__(self, /, other: OrderedDefaultNe) -> bool: ...
28+
def __new__(cls, /, value: int) -> None: ...
29+
30+
class OrderedRichCmp:
31+
def __eq__(self, /, other: OrderedRichCmp) -> bool: ...
32+
def __ge__(self, /, other: OrderedRichCmp) -> bool: ...
33+
def __gt__(self, /, other: OrderedRichCmp) -> bool: ...
34+
def __le__(self, /, other: OrderedRichCmp) -> bool: ...
35+
def __lt__(self, /, other: OrderedRichCmp) -> bool: ...
36+
def __ne__(self, /, other: OrderedRichCmp) -> bool: ...
37+
def __new__(cls, /, value: int) -> None: ...

pytests/tests/test_comparisons.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
EqDerived,
99
Ordered,
1010
OrderedDefaultNe,
11+
OrderedRichCmp,
1112
)
1213
from typing_extensions import Self
1314

@@ -132,8 +133,10 @@ def __ge__(self, other: Self) -> bool:
132133
return self.x >= other.x
133134

134135

135-
@pytest.mark.parametrize("ty", (Ordered, PyOrdered), ids=("rust", "python"))
136-
def test_ordered(ty: Type[Union[Ordered, PyOrdered]]):
136+
@pytest.mark.parametrize(
137+
"ty", (Ordered, OrderedRichCmp, PyOrdered), ids=("rust", "rust-richcmp", "python")
138+
)
139+
def test_ordered(ty: Type[Union[Ordered, OrderedRichCmp, PyOrdered]]):
137140
a = ty(0)
138141
b = ty(0)
139142
c = ty(1)

src/types/boolobject.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ impl<'py> IntoPyObject<'py> for bool {
141141
type Output = Borrowed<'py, 'py, Self::Target>;
142142
type Error = Infallible;
143143

144+
#[cfg(feature = "experimental-inspect")]
145+
const OUTPUT_TYPE: &'static str = "bool";
146+
144147
#[inline]
145148
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
146149
Ok(PyBool::new(py, self))
@@ -157,6 +160,9 @@ impl<'py> IntoPyObject<'py> for &bool {
157160
type Output = Borrowed<'py, 'py, Self::Target>;
158161
type Error = Infallible;
159162

163+
#[cfg(feature = "experimental-inspect")]
164+
const OUTPUT_TYPE: &'static str = bool::OUTPUT_TYPE;
165+
160166
#[inline]
161167
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
162168
(*self).into_pyobject(py)

0 commit comments

Comments
 (0)