Skip to content

Commit

Permalink
reintroduce PyBuffer::get (#4486)
Browse files Browse the repository at this point in the history
  • Loading branch information
Icxolu committed Aug 24, 2024
1 parent 7d399ff commit 8446937
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pytests/src/buf_and_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl BytesExtractor {

#[staticmethod]
pub fn from_buffer(buf: &Bound<'_, PyAny>) -> PyResult<usize> {
let buf = PyBuffer::<u8>::get_bound(buf)?;
let buf = PyBuffer::<u8>::get(buf)?;
Ok(buf.item_count())
}
}
Expand Down
19 changes: 13 additions & 6 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ pub unsafe trait Element: Copy {

impl<'py, T: Element> FromPyObject<'py> for PyBuffer<T> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
Self::get_bound(obj)
Self::get(obj)
}
}

impl<T: Element> PyBuffer<T> {
/// Gets the underlying buffer from the specified python object.
pub fn get_bound(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
pub fn get(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
// TODO: use nightly API Box::new_uninit() once stable
let mut buf = Box::new(mem::MaybeUninit::uninit());
let buf: Box<ffi::Py_buffer> = {
Expand Down Expand Up @@ -224,6 +224,13 @@ impl<T: Element> PyBuffer<T> {
}
}

/// Deprecated name for [`PyBuffer::get`].
#[deprecated(since = "0.23.0", note = "renamed to `PyBuffer::get`")]
#[inline]
pub fn get_bound(obj: &Bound<'_, PyAny>) -> PyResult<PyBuffer<T>> {
Self::get(obj)
}

/// Gets the pointer to the start of the buffer memory.
///
/// Warning: the buffer memory might be mutated by other Python functions,
Expand Down Expand Up @@ -686,7 +693,7 @@ mod tests {
fn test_debug() {
Python::with_gil(|py| {
let bytes = py.eval(ffi::c_str!("b'abcde'"), None, None).unwrap();
let buffer: PyBuffer<u8> = PyBuffer::get_bound(&bytes).unwrap();
let buffer: PyBuffer<u8> = PyBuffer::get(&bytes).unwrap();
let expected = format!(
concat!(
"PyBuffer {{ buf: {:?}, obj: {:?}, ",
Expand Down Expand Up @@ -848,7 +855,7 @@ mod tests {
fn test_bytes_buffer() {
Python::with_gil(|py| {
let bytes = py.eval(ffi::c_str!("b'abcde'"), None, None).unwrap();
let buffer = PyBuffer::get_bound(&bytes).unwrap();
let buffer = PyBuffer::get(&bytes).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 5);
assert_eq!(buffer.format().to_str().unwrap(), "B");
Expand Down Expand Up @@ -884,7 +891,7 @@ mod tests {
.unwrap()
.call_method("array", ("f", (1.0, 1.5, 2.0, 2.5)), None)
.unwrap();
let buffer = PyBuffer::get_bound(&array).unwrap();
let buffer = PyBuffer::get(&array).unwrap();
assert_eq!(buffer.dimensions(), 1);
assert_eq!(buffer.item_count(), 4);
assert_eq!(buffer.format().to_str().unwrap(), "f");
Expand Down Expand Up @@ -914,7 +921,7 @@ mod tests {
assert_eq!(buffer.to_vec(py).unwrap(), [10.0, 11.0, 12.0, 13.0]);

// F-contiguous fns
let buffer = PyBuffer::get_bound(&array).unwrap();
let buffer = PyBuffer::get(&array).unwrap();
let slice = buffer.as_fortran_slice(py).unwrap();
assert_eq!(slice.len(), 4);
assert_eq!(slice[1].get(), 11.0);
Expand Down
12 changes: 6 additions & 6 deletions tests/test_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,43 +95,43 @@ fn test_get_buffer_errors() {
)
.unwrap();

assert!(PyBuffer::<u32>::get_bound(instance.bind(py)).is_ok());
assert!(PyBuffer::<u32>::get(instance.bind(py)).is_ok());

instance.borrow_mut(py).error = Some(TestGetBufferError::NullShape);
assert_eq!(
PyBuffer::<u32>::get_bound(instance.bind(py))
PyBuffer::<u32>::get(instance.bind(py))
.unwrap_err()
.to_string(),
"BufferError: shape is null"
);

instance.borrow_mut(py).error = Some(TestGetBufferError::NullStrides);
assert_eq!(
PyBuffer::<u32>::get_bound(instance.bind(py))
PyBuffer::<u32>::get(instance.bind(py))
.unwrap_err()
.to_string(),
"BufferError: strides is null"
);

instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectItemSize);
assert_eq!(
PyBuffer::<u32>::get_bound(instance.bind(py))
PyBuffer::<u32>::get(instance.bind(py))
.unwrap_err()
.to_string(),
"BufferError: buffer contents are not compatible with u32"
);

instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectFormat);
assert_eq!(
PyBuffer::<u32>::get_bound(instance.bind(py))
PyBuffer::<u32>::get(instance.bind(py))
.unwrap_err()
.to_string(),
"BufferError: buffer contents are not compatible with u32"
);

instance.borrow_mut(py).error = Some(TestGetBufferError::IncorrectAlignment);
assert_eq!(
PyBuffer::<u32>::get_bound(instance.bind(py))
PyBuffer::<u32>::get(instance.bind(py))
.unwrap_err()
.to_string(),
"BufferError: buffer contents are insufficiently aligned for u32"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_buffer_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ fn test_buffer_referenced() {
}
.into_py(py);

let buf = PyBuffer::<u8>::get_bound(instance.bind(py)).unwrap();
let buf = PyBuffer::<u8>::get(instance.bind(py)).unwrap();
assert_eq!(buf.to_vec(py).unwrap(), input);
drop(instance);
buf
Expand Down

0 comments on commit 8446937

Please sign in to comment.