Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added as_super methods to PyRef and PyRefMut. #4219

Merged
merged 16 commits into from
Jun 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,12 @@ explicitly.

To get a parent class from a child, use [`PyRef`] instead of `&self` for methods,
or [`PyRefMut`] instead of `&mut self`.
Then you can access a parent class by `self_.as_ref()` as `&Self::BaseClass`,
or by `self_.into_super()` as `PyRef<Self::BaseClass>`.
Then you can access a parent class by `self_.as_super()` as `&PyRef<Self::BaseClass>`,
or by `self_.into_super()` as `PyRef<Self::BaseClass>` (and similar for the `PyRefMut`
case). For convenience, `self_.as_ref()` can also be used to get `&Self::BaseClass`
directly; however, this approach does not let you access base clases higher in the
inheritance hierarchy, for which you would need to chain multiple `as_super` or
`into_super` calls.

```rust
# use pyo3::prelude::*;
Expand All @@ -345,7 +349,7 @@ impl BaseClass {
BaseClass { val1: 10 }
}

pub fn method(&self) -> PyResult<usize> {
pub fn method1(&self) -> PyResult<usize> {
Ok(self.val1)
}
}
Expand All @@ -363,8 +367,8 @@ impl SubClass {
}

fn method2(self_: PyRef<'_, Self>) -> PyResult<usize> {
let super_ = self_.as_ref(); // Get &BaseClass
super_.method().map(|x| x * self_.val2)
let super_ = self_.as_super(); // Get &PyRef<BaseClass>
super_.method1().map(|x| x * self_.val2)
}
}

Expand All @@ -381,11 +385,28 @@ impl SubSubClass {
}

fn method3(self_: PyRef<'_, Self>) -> PyResult<usize> {
let base = self_.as_super().as_super(); // Get &PyRef<'_, BaseClass>
base.method1().map(|x| x * self_.val3)
}

fn method4(self_: PyRef<'_, Self>) -> PyResult<usize> {
let v = self_.val3;
let super_ = self_.into_super(); // Get PyRef<'_, SubClass>
SubClass::method2(super_).map(|x| x * v)
}

fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) {
let val1 = self_.as_super().as_super().val1;
let val2 = self_.as_super().val2;
(val1, val2, self_.val3)
}

fn double_values(mut self_: PyRefMut<'_, Self>) {
self_.as_super().as_super().val1 *= 2;
self_.as_super().val2 *= 2;
self_.val3 *= 2;
}

#[staticmethod]
fn factory_method(py: Python<'_>, val: usize) -> PyResult<PyObject> {
let base = PyClassInitializer::from(BaseClass::new());
Expand All @@ -400,7 +421,13 @@ impl SubSubClass {
}
# Python::with_gil(|py| {
# let subsub = pyo3::Py::new(py, SubSubClass::new()).unwrap();
# pyo3::py_run!(py, subsub, "assert subsub.method3() == 3000");
# pyo3::py_run!(py, subsub, "assert subsub.method1() == 10");
# pyo3::py_run!(py, subsub, "assert subsub.method2() == 150");
# pyo3::py_run!(py, subsub, "assert subsub.method3() == 200");
# pyo3::py_run!(py, subsub, "assert subsub.method4() == 3000");
# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (10, 15, 20)");
# pyo3::py_run!(py, subsub, "assert subsub.double_values() == None");
# pyo3::py_run!(py, subsub, "assert subsub.get_values() == (20, 30, 40)");
# let subsub = SubSubClass::factory_method(py, 2).unwrap();
# let subsubsub = SubSubClass::factory_method(py, 3).unwrap();
# let cls = py.get_type_bound::<SubSubClass>();
Expand Down
3 changes: 3 additions & 0 deletions newsfragments/4219.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Added `as_super` methods to `PyRef` and `PyRefMut` for accesing the base class by reference
- Updated user guide to recommend `as_super` for referencing the base class instead of `as_ref`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to include docs changes or internal changes here, but as this PR is otherwise great let's just leave these points here and I'll tidy up during release.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to know, I'll keep that in mind next time. Thanks for the merge!

- Added `pyo3::internal_tricks::ptr_from_mut` function for casting `&mut T` to `*mut T`
6 changes: 6 additions & 0 deletions src/internal_tricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,3 +223,9 @@ pub(crate) fn extract_c_string(
pub(crate) const fn ptr_from_ref<T>(t: &T) -> *const T {
t as *const T
}

// TODO: use ptr::from_mut on MSRV 1.76
#[inline]
pub(crate) fn ptr_from_mut<T>(t: &mut T) -> *mut T {
t as *mut T
}
168 changes: 164 additions & 4 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@
use crate::conversion::AsPyPointer;
use crate::exceptions::PyRuntimeError;
use crate::ffi_ptr_ext::FfiPtrExt;
use crate::internal_tricks::{ptr_from_mut, ptr_from_ref};
use crate::pyclass::{boolean_struct::False, PyClass};
use crate::types::any::PyAnyMethods;
#[cfg(feature = "gil-refs")]
use crate::{
conversion::ToPyObject,
impl_::pyclass::PyClassImpl,
internal_tricks::ptr_from_ref,
pyclass::boolean_struct::True,
pyclass_init::PyClassInitializer,
type_object::{PyLayout, PySizedLayout},
Expand Down Expand Up @@ -612,6 +612,7 @@ impl<T: PyClass + fmt::Debug> fmt::Debug for PyCell<T> {
/// ```
///
/// See the [module-level documentation](self) for more information.
#[repr(transparent)]
pub struct PyRef<'p, T: PyClass> {
// TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to
// store `Borrowed` here instead, avoiding reference counting overhead.
Expand All @@ -631,7 +632,7 @@ where
U: PyClass,
{
fn as_ref(&self) -> &T::BaseType {
unsafe { &*self.inner.get_class_object().ob_base.get_ptr() }
self.as_super()
}
}

Expand Down Expand Up @@ -743,6 +744,58 @@ where
},
}
}

/// Borrows a shared reference to `PyRef<T::BaseType>`.
///
/// With the help of this method, you can access attributes and call methods
/// on the superclass without consuming the `PyRef<T>`. This method can also
/// be chained to access the super-superclass (and so on).
///
/// # Examples
/// ```
/// # use pyo3::prelude::*;
/// #[pyclass(subclass)]
/// struct Base {
/// base_name: &'static str,
/// }
/// #[pymethods]
/// impl Base {
/// fn base_name_len(&self) -> usize {
/// self.base_name.len()
/// }
/// }
///
/// #[pyclass(extends=Base)]
/// struct Sub {
/// sub_name: &'static str,
/// }
///
/// #[pymethods]
/// impl Sub {
/// #[new]
/// fn new() -> (Self, Base) {
/// (Self { sub_name: "sub_name" }, Base { base_name: "base_name" })
/// }
/// fn sub_name_len(&self) -> usize {
/// self.sub_name.len()
/// }
/// fn format_name_lengths(slf: PyRef<'_, Self>) -> String {
/// format!("{} {}", slf.as_super().base_name_len(), slf.sub_name_len())
/// }
/// }
/// # Python::with_gil(|py| {
/// # let sub = Py::new(py, Sub::new()).unwrap();
/// # pyo3::py_run!(py, sub, "assert sub.format_name_lengths() == '9 8'")
/// # });
/// ```
pub fn as_super(&self) -> &PyRef<'p, U> {
let ptr = ptr_from_ref::<Bound<'p, T>>(&self.inner)
// `Bound<T>` has the same layout as `Bound<T::BaseType>`
.cast::<Bound<'p, T::BaseType>>()
// `Bound<T::BaseType>` has the same layout as `PyRef<T::BaseType>`
.cast::<PyRef<'p, T::BaseType>>();
unsafe { &*ptr }
}
}

impl<'p, T: PyClass> Deref for PyRef<'p, T> {
Expand Down Expand Up @@ -799,6 +852,7 @@ impl<T: PyClass + fmt::Debug> fmt::Debug for PyRef<'_, T> {
/// A wrapper type for a mutably borrowed value from a [`Bound<'py, T>`].
///
/// See the [module-level documentation](self) for more information.
#[repr(transparent)]
pub struct PyRefMut<'p, T: PyClass<Frozen = False>> {
// TODO: once the GIL Ref API is removed, consider adding a lifetime parameter to `PyRef` to
// store `Borrowed` here instead, avoiding reference counting overhead.
Expand All @@ -818,7 +872,7 @@ where
U: PyClass<Frozen = False>,
{
fn as_ref(&self) -> &T::BaseType {
unsafe { &*self.inner.get_class_object().ob_base.get_ptr() }
PyRefMut::downgrade(self).as_super()
}
}

Expand All @@ -828,7 +882,7 @@ where
U: PyClass<Frozen = False>,
{
fn as_mut(&mut self) -> &mut T::BaseType {
unsafe { &mut *self.inner.get_class_object().ob_base.get_ptr() }
self.as_super()
}
}

Expand Down Expand Up @@ -870,6 +924,11 @@ impl<'py, T: PyClass<Frozen = False>> PyRefMut<'py, T> {
.try_borrow_mut()
.map(|_| Self { inner: obj.clone() })
}

pub(crate) fn downgrade(slf: &Self) -> &PyRef<'py, T> {
// `PyRefMut<T>` and `PyRef<T>` have the same layout
unsafe { &*ptr_from_ref(slf).cast() }
}
}

impl<'p, T, U> PyRefMut<'p, T>
Expand All @@ -891,6 +950,23 @@ where
},
}
}

/// Borrows a mutable reference to `PyRefMut<T::BaseType>`.
///
/// With the help of this method, you can mutate attributes and call mutating
/// methods on the superclass without consuming the `PyRefMut<T>`. This method
/// can also be chained to access the super-superclass (and so on).
///
/// See [`PyRef::as_super`] for more.
pub fn as_super(&mut self) -> &mut PyRefMut<'p, U> {
let ptr = ptr_from_mut::<Bound<'p, T>>(&mut self.inner)
// `Bound<T>` has the same layout as `Bound<T::BaseType>`
.cast::<Bound<'p, T::BaseType>>()
// `Bound<T::BaseType>` has the same layout as `PyRefMut<T::BaseType>`,
// and the mutable borrow on `self` prevents aliasing
.cast::<PyRefMut<'p, T::BaseType>>();
unsafe { &mut *ptr }
}
}

impl<'p, T: PyClass<Frozen = False>> Deref for PyRefMut<'p, T> {
Expand Down Expand Up @@ -1140,4 +1216,88 @@ mod tests {
unsafe { ffi::Py_DECREF(ptr) };
})
}

#[crate::pyclass]
#[pyo3(crate = "crate", subclass)]
struct BaseClass {
val1: usize,
}

#[crate::pyclass]
#[pyo3(crate = "crate", extends=BaseClass, subclass)]
struct SubClass {
val2: usize,
}

#[crate::pyclass]
#[pyo3(crate = "crate", extends=SubClass)]
struct SubSubClass {
val3: usize,
}

#[crate::pymethods]
#[pyo3(crate = "crate")]
impl SubSubClass {
#[new]
fn new(py: Python<'_>) -> crate::Py<SubSubClass> {
let init = crate::PyClassInitializer::from(BaseClass { val1: 10 })
.add_subclass(SubClass { val2: 15 })
.add_subclass(SubSubClass { val3: 20 });
crate::Py::new(py, init).expect("allocation error")
}

fn get_values(self_: PyRef<'_, Self>) -> (usize, usize, usize) {
let val1 = self_.as_super().as_super().val1;
let val2 = self_.as_super().val2;
(val1, val2, self_.val3)
}

fn double_values(mut self_: PyRefMut<'_, Self>) {
self_.as_super().as_super().val1 *= 2;
self_.as_super().val2 *= 2;
self_.val3 *= 2;
}
}

#[test]
fn test_pyref_as_super() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py).into_bound(py);
let pyref = obj.borrow();
assert_eq!(pyref.as_super().as_super().val1, 10);
assert_eq!(pyref.as_super().val2, 15);
assert_eq!(pyref.as_ref().val2, 15); // `as_ref` also works
assert_eq!(pyref.val3, 20);
assert_eq!(SubSubClass::get_values(pyref), (10, 15, 20));
});
}

#[test]
fn test_pyrefmut_as_super() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py).into_bound(py);
assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 15, 20));
{
let mut pyrefmut = obj.borrow_mut();
assert_eq!(pyrefmut.as_super().as_ref().val1, 10);
pyrefmut.as_super().as_super().val1 -= 5;
pyrefmut.as_super().val2 -= 3;
pyrefmut.as_mut().val2 -= 2; // `as_mut` also works
pyrefmut.val3 -= 5;
}
assert_eq!(SubSubClass::get_values(obj.borrow()), (5, 10, 15));
SubSubClass::double_values(obj.borrow_mut());
assert_eq!(SubSubClass::get_values(obj.borrow()), (10, 20, 30));
});
}

#[test]
fn test_pyrefs_in_python() {
Python::with_gil(|py| {
let obj = SubSubClass::new(py);
crate::py_run!(py, obj, "assert obj.get_values() == (10, 15, 20)");
crate::py_run!(py, obj, "assert obj.double_values() is None");
crate::py_run!(py, obj, "assert obj.get_values() == (20, 30, 40)");
});
}
}
Loading