Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PyClassBorrowChecker thread safe #4544

Merged
merged 22 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newsfragments/4544.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
* Refactored runtime borrow checking for mutable pyclass instances
to be thread-safe when the GIL is disabled.
26 changes: 26 additions & 0 deletions pytests/src/pyclasses.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::{thread, time};

use pyo3::exceptions::{PyStopIteration, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyType;
Expand Down Expand Up @@ -43,6 +45,29 @@ impl PyClassIter {
}
}

#[pyclass]
#[derive(Default)]
struct PyClassThreadIter {
count: usize,
}

#[pymethods]
impl PyClassThreadIter {
#[new]
pub fn new() -> Self {
Default::default()
}

fn __next__(&mut self, py: Python<'_>) -> usize {
let current_count = self.count;
self.count += 1;
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
if current_count == 0 {
py.allow_threads(|| thread::sleep(time::Duration::from_millis(100)));
}
self.count
}
}

/// Demonstrates a base class which can operate on the relevant subclass in its constructor.
#[pyclass(subclass)]
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -83,6 +108,7 @@ impl ClassWithDict {
pub fn pyclasses(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<EmptyClass>()?;
m.add_class::<PyClassIter>()?;
m.add_class::<PyClassThreadIter>()?;
m.add_class::<AssertingBaseClass>()?;
m.add_class::<ClassWithoutConstructor>()?;
#[cfg(any(Py_3_10, not(Py_LIMITED_API)))]
Expand Down
16 changes: 16 additions & 0 deletions pytests/tests/test_pyclasses.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import concurrent.futures
from typing import Type

import pytest
Expand Down Expand Up @@ -53,6 +54,21 @@ def test_iter():
assert excinfo.value.value == "Ended"


def test_parallel_iter():
i = pyclasses.PyClassThreadIter()

def func():
next(i)

# the second thread attempts to borrow a reference to the instance's
# state while the first thread is still sleeping, so we trigger a
# runtime borrow-check error
with pytest.raises(RuntimeError, match="Already borrowed"):
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as tpe:
futures = [tpe.submit(func), tpe.submit(func)]
[f.result() for f in futures]


class AssertingSubClass(pyclasses.AssertingBaseClass):
pass

Expand Down
90 changes: 63 additions & 27 deletions src/pycell/impl_.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#![allow(missing_docs)]
//! Crate-private implementation of PyClassObject

use std::cell::{Cell, UnsafeCell};
use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::impl_::pyclass::{
PyClassBaseType, PyClassDict, PyClassImpl, PyClassThreadChecker, PyClassWeakRef,
Expand Down Expand Up @@ -50,22 +51,54 @@ impl<M: PyClassMutability> PyClassMutability for ExtendsMutableAncestor<M> {
type MutableChild = ExtendsMutableAncestor<MutableClass>;
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct BorrowFlag(usize);
#[derive(Debug)]
struct BorrowFlag(AtomicUsize);

impl BorrowFlag {
pub(crate) const UNUSED: BorrowFlag = BorrowFlag(0);
const HAS_MUTABLE_BORROW: BorrowFlag = BorrowFlag(usize::MAX);
const fn increment(self) -> Self {
Self(self.0 + 1)
pub(crate) const UNUSED: usize = 0;
const HAS_MUTABLE_BORROW: usize = usize::MAX;
fn increment(&self) -> Result<(), PyBorrowError> {
let mut value = self.0.load(Ordering::Relaxed);
if value == BorrowFlag::HAS_MUTABLE_BORROW {
return Err(PyBorrowError { _private: () });
}
loop {
match self.0.compare_exchange(
// only increment if the value hasn't changed since the
// last atomic load
value,
value + 1,
// on success, the write is synchronized to ensure other threads
// can't acquire any references
Ordering::Release,
// on failure, the read is synchronized to ensure the borrowed reference
// state is observed
Ordering::Acquire,
) {
Ok(..) => {
// value successfully incremented
break Ok(());
}
Err(changed_value) => {
// value changed under us, need to try again
if changed_value == BorrowFlag::HAS_MUTABLE_BORROW {
return Err(PyBorrowError { _private: () });
}
value = changed_value;
}
}
}
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved
}
const fn decrement(self) -> Self {
Self(self.0 - 1)
fn decrement(&self) {
// impossible to get into a bad state from here so relaxed
// ordering is fine, the decrement only needs to eventually
// be visible
self.0.fetch_sub(1, Ordering::Relaxed);
}
}

pub struct EmptySlot(());
pub struct BorrowChecker(Cell<BorrowFlag>);
pub struct BorrowChecker(BorrowFlag);

pub trait PyClassBorrowChecker {
/// Initial value for self
Expand Down Expand Up @@ -110,36 +143,39 @@ impl PyClassBorrowChecker for EmptySlot {
impl PyClassBorrowChecker for BorrowChecker {
#[inline]
fn new() -> Self {
Self(Cell::new(BorrowFlag::UNUSED))
Self(BorrowFlag(AtomicUsize::new(BorrowFlag::UNUSED)))
}

fn try_borrow(&self) -> Result<(), PyBorrowError> {
let flag = self.0.get();
if flag != BorrowFlag::HAS_MUTABLE_BORROW {
self.0.set(flag.increment());
Ok(())
} else {
Err(PyBorrowError { _private: () })
}
self.0.increment()
}

fn release_borrow(&self) {
let flag = self.0.get();
self.0.set(flag.decrement())
self.0.decrement();
}

fn try_borrow_mut(&self) -> Result<(), PyBorrowMutError> {
let flag = self.0.get();
if flag == BorrowFlag::UNUSED {
self.0.set(BorrowFlag::HAS_MUTABLE_BORROW);
Ok(())
} else {
Err(PyBorrowMutError { _private: () })
let flag = &self.0;
match flag.0.compare_exchange(
// only allowed to transition to mutable borrow if the reference is
// currently unused
BorrowFlag::UNUSED,
BorrowFlag::HAS_MUTABLE_BORROW,
// On success the read is synchronized to ensure other
// threads don't get a reference before this thread checks
// that it can get one
Ordering::Acquire,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be AcqRel?

The documentation for compare_exchange states:

Using Acquire as success ordering makes the store part of this operation Relaxed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not totally convinced it's necessary to establish a happens-before relationship with other threads when they observe the successfully stored flag (as they won't be allowed to read the accompanying data), but I think this also cannot hurt.

// It doesn't matter precisely when the failure gets turned
// into an error
Ordering::Relaxed,
) {
Ok(..) => Ok(()),
Err(..) => Err(PyBorrowMutError { _private: () }),
}
}

fn release_borrow_mut(&self) {
self.0.set(BorrowFlag::UNUSED)
self.0 .0.store(BorrowFlag::UNUSED, Ordering::Release)
}
}

Expand Down
Loading