diff --git a/newsfragments/5517.added.md b/newsfragments/5517.added.md
new file mode 100644
index 00000000000..ddd89fddc18
--- /dev/null
+++ b/newsfragments/5517.added.md
@@ -0,0 +1 @@
+Add `PyBytesWriter`
diff --git a/pyo3-ffi-check/build.rs b/pyo3-ffi-check/build.rs
index e7cfbe40df3..ecc55f9736d 100644
--- a/pyo3-ffi-check/build.rs
+++ b/pyo3-ffi-check/build.rs
@@ -59,6 +59,7 @@ fn main() {
         .blocklist_item("FP_INT_TONEARESTFROMZERO")
         .blocklist_item("FP_INT_TONEAREST")
         .blocklist_item("FP_ZERO")
+        .blocklist_item("PyBytesWriter")
         .generate()
         .expect("Unable to generate bindings");
 
diff --git a/pyo3-ffi/src/bytesobject.rs b/pyo3-ffi/src/bytesobject.rs
index 08351d18daa..5d37cdb413e 100644
--- a/pyo3-ffi/src/bytesobject.rs
+++ b/pyo3-ffi/src/bytesobject.rs
@@ -56,6 +56,18 @@ extern "C" {
     ) -> c_int;
 }
 
+#[cfg(Py_3_15)]
+opaque_struct!(pub PyBytesWriter);
+
+/// Layout read and written by the `compat::py_3_15` fallbacks on Python versions before 3.15.
+#[repr(C)]
+#[cfg(not(Py_3_15))]
+pub struct PyBytesWriter {
+    pub(crate) small_buffer: [c_char; 256],
+    pub(crate) obj: *mut PyObject,
+    pub(crate) size: Py_ssize_t,
+}
+
 // skipped F_LJUST
 // skipped F_SIGN
 // skipped F_BLANK
diff --git a/pyo3-ffi/src/compat/mod.rs b/pyo3-ffi/src/compat/mod.rs
index 044ea46762b..803db48acf5 100644
--- a/pyo3-ffi/src/compat/mod.rs
+++ b/pyo3-ffi/src/compat/mod.rs
@@ -53,9 +53,11 @@ macro_rules! compat_function {
 mod py_3_10;
 mod py_3_13;
 mod py_3_14;
+mod py_3_15;
 mod py_3_9;
 
 pub use self::py_3_10::*;
 pub use self::py_3_13::*;
 pub use self::py_3_14::*;
+pub use self::py_3_15::*;
 pub use self::py_3_9::*;
diff --git a/pyo3-ffi/src/compat/py_3_15.rs b/pyo3-ffi/src/compat/py_3_15.rs
new file mode 100644
index 00000000000..5645631d767
--- /dev/null
+++ b/pyo3-ffi/src/compat/py_3_15.rs
@@ -0,0 +1,241 @@
+//! Fallback implementations of the `PyBytesWriter` C API introduced in CPython 3.15.
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Allocates a writer whose initial length is `size`; sets `ValueError` if `size` is negative.
+    #[inline]
+    pub unsafe fn PyBytesWriter_Create(
+        size: crate::Py_ssize_t,
+    ) -> *mut crate::PyBytesWriter {
+
+        if size < 0 {
+            crate::PyErr_SetString(crate::PyExc_ValueError, c_str!("size must be >= 0").as_ptr() as *const _);
+            return std::ptr::null_mut();
+        }
+
+        let writer: *mut crate::PyBytesWriter = crate::PyMem_Malloc(std::mem::size_of::<crate::PyBytesWriter>()).cast();
+        if writer.is_null() {
+            crate::PyErr_NoMemory();
+            return std::ptr::null_mut();
+        }
+
+        (*writer).obj = std::ptr::null_mut();
+        (*writer).size = 0;
+
+        if size >= 1 {
+            if _PyBytesWriter_Resize_impl(writer, size, 0) < 0 {
+                PyBytesWriter_Discard(writer);
+                return std::ptr::null_mut();
+            }
+
+            (*writer).size = size;
+        }
+
+        writer
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Frees the writer and drops its bytes object, if any; a null `writer` is a no-op.
+    #[inline]
+    pub unsafe fn PyBytesWriter_Discard(writer: *mut crate::PyBytesWriter) -> () {
+        if writer.is_null() {
+            return;
+        }
+
+        crate::Py_XDECREF((*writer).obj);
+        crate::PyMem_Free(writer.cast());
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Consumes the writer and returns its current contents as a bytes object.
+    #[inline]
+    pub unsafe fn PyBytesWriter_Finish(writer: *mut crate::PyBytesWriter) -> *mut crate::PyObject {
+        PyBytesWriter_FinishWithSize(writer, (*writer).size)
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Consumes the writer and returns its first `size` bytes as a bytes object.
+    #[inline]
+    pub unsafe fn PyBytesWriter_FinishWithSize(writer: *mut crate::PyBytesWriter, size: crate::Py_ssize_t) -> *mut crate::PyObject {
+        let result = if size == 0 {
+            crate::PyBytes_FromStringAndSize(c_str!("").as_ptr(), 0)
+        } else if (*writer).obj.is_null() {
+            crate::PyBytes_FromStringAndSize((*writer).small_buffer.as_ptr(), size)
+        } else {
+            if size != crate::PyBytes_Size((*writer).obj) && crate::_PyBytes_Resize(&mut (*writer).obj, size) < 0 {
+                PyBytesWriter_Discard(writer);
+                return std::ptr::null_mut();
+            }
+            std::mem::replace(&mut (*writer).obj, std::ptr::null_mut())
+        };
+
+        PyBytesWriter_Discard(writer);
+        result
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Returns the number of bytes the current backing storage can hold.
+    #[inline]
+    pub unsafe fn _PyBytesWriter_GetAllocated(writer: *mut crate::PyBytesWriter) -> crate::Py_ssize_t {
+        if (*writer).obj.is_null() {
+            std::mem::size_of_val(&(*writer).small_buffer) as _
+        } else {
+            crate::PyBytes_Size((*writer).obj)
+        }
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Returns a writable pointer to the start of the backing storage.
+    #[inline]
+    pub unsafe fn PyBytesWriter_GetData(writer: *mut crate::PyBytesWriter) -> *mut std::ffi::c_void {
+        if (*writer).obj.is_null() {
+            // No intermediate `&`: callers write through the returned pointer.
+            std::ptr::addr_of_mut!((*writer).small_buffer).cast::<std::ffi::c_void>()
+        } else {
+            crate::PyBytes_AS_STRING((*writer).obj) as *mut _
+        }
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Returns the number of bytes written so far.
+    #[inline]
+    pub unsafe fn PyBytesWriter_GetSize(writer: *mut crate::PyBytesWriter) -> crate::Py_ssize_t {
+        (*writer).size
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Appends `size` bytes from `bytes`; a negative `size` means "up to the NUL terminator".
+    #[inline]
+    pub unsafe fn PyBytesWriter_WriteBytes(writer: *mut crate::PyBytesWriter, bytes: *const std::ffi::c_void, size: crate::Py_ssize_t) -> std::ffi::c_int {
+        let size = if size < 0 {
+            let len = libc::strlen(bytes as _);
+            if len > crate::PY_SSIZE_T_MAX as libc::size_t {
+                crate::PyErr_NoMemory();
+                return -1;
+            }
+            len as crate::Py_ssize_t
+        } else {
+            size
+        };
+
+        let pos = (*writer).size;
+        if PyBytesWriter_Grow(writer, size) < 0 {
+            return -1;
+        }
+
+        let buf = PyBytesWriter_GetData(writer);
+        std::ptr::copy_nonoverlapping(bytes, buf.add(pos as usize), size as usize);
+        0
+    }
+);
+
+#[cfg(not(Py_LIMITED_API))]
+compat_function!(
+    originally_defined_for(all(Py_3_15, not(Py_LIMITED_API)));
+
+    // Changes the length by `size` bytes (which may be negative), reallocating if needed.
+    #[inline]
+    pub unsafe fn PyBytesWriter_Grow(writer: *mut crate::PyBytesWriter, size: crate::Py_ssize_t) -> std::ffi::c_int {
+        if size < 0 && (*writer).size + size < 0 {
+            crate::PyErr_SetString(crate::PyExc_ValueError, c_str!("invalid size").as_ptr());
+            return -1;
+        }
+
+        if size > crate::PY_SSIZE_T_MAX - (*writer).size {
+            crate::PyErr_NoMemory();
+            return -1;
+        }
+        let new_size = (*writer).size + size;
+
+        if _PyBytesWriter_Resize_impl(writer, new_size, 1) < 0 {
+            return -1;
+        }
+
+        (*writer).size = new_size;
+        0
+    }
+);
+
+/// Ensures the backing storage holds at least `size` bytes, over-allocating when `resize > 0`.
+/// Returns 0 on success and -1 on failure.
+#[inline]
+#[cfg(not(Py_LIMITED_API))]
+unsafe fn _PyBytesWriter_Resize_impl(
+    writer: *mut crate::PyBytesWriter,
+    mut size: crate::Py_ssize_t,
+    resize: std::ffi::c_int,
+) -> std::ffi::c_int {
+    let overallocate = resize;
+    assert!(size >= 0);
+
+    if size <= _PyBytesWriter_GetAllocated(writer) {
+        return 0;
+    }
+
+    if overallocate > 0 {
+        #[cfg(windows)]
+        if size <= (crate::PY_SSIZE_T_MAX - size / 2) {
+            size += size / 2;
+        }
+
+        #[cfg(not(windows))]
+        if size <= (crate::PY_SSIZE_T_MAX - size / 4) {
+            size += size / 4;
+        }
+    }
+
+    if !(*writer).obj.is_null() {
+        // `_PyBytes_Resize` reports failure as -1, never as a positive value.
+        if crate::_PyBytes_Resize(&mut (*writer).obj, size) < 0 {
+            return -1;
+        }
+        assert!(!(*writer).obj.is_null())
+    } else {
+        (*writer).obj = crate::PyBytes_FromStringAndSize(std::ptr::null_mut(), size);
+        if (*writer).obj.is_null() {
+            return -1;
+        }
+
+        if resize > 0 {
+            assert!((size as usize) > std::mem::size_of_val(&(*writer).small_buffer));
+
+            std::ptr::copy_nonoverlapping(
+                (*writer).small_buffer.as_ptr(),
+                crate::PyBytes_AS_STRING((*writer).obj) as *mut _,
+                std::mem::size_of_val(&(*writer).small_buffer),
+            );
+        }
+    }
+
+    0
+}
diff --git a/src/byteswriter.rs b/src/byteswriter.rs
new file mode 100644
index 00000000000..fdf17fad6ec
--- /dev/null
+++ b/src/byteswriter.rs
@@ -0,0 +1,252 @@
+use crate::types::PyBytes;
+#[cfg(not(Py_LIMITED_API))]
+use crate::{
+    ffi::{
+        self,
+        compat::{
+            PyBytesWriter_Create, PyBytesWriter_Discard, PyBytesWriter_Finish,
+            PyBytesWriter_GetData, PyBytesWriter_GetSize, PyBytesWriter_Grow,
+            PyBytesWriter_WriteBytes, _PyBytesWriter_GetAllocated,
+        },
+    },
+    ffi_ptr_ext::FfiPtrExt,
+};
+use crate::{Bound, IntoPyObject, PyErr, PyResult, Python};
+use std::convert::Infallible;
+use std::io::IoSlice;
+#[cfg(not(Py_LIMITED_API))]
+use std::{
+    mem::ManuallyDrop,
+    ptr::{self, NonNull},
+};
+
+/// A growable byte buffer that is turned into a Python `bytes` object once writing is done.
+pub struct PyBytesWriter<'py> {
+    python: Python<'py>,
+    #[cfg(not(Py_LIMITED_API))]
+    writer: NonNull<ffi::PyBytesWriter>,
+    #[cfg(Py_LIMITED_API)]
+    buffer: Vec<u8>,
+}
+
+impl<'py> PyBytesWriter<'py> {
+    /// Creates an empty writer.
+    #[inline]
+    pub fn new(py: Python<'py>) -> PyResult<Self> {
+        Self::with_capacity(py, 0)
+    }
+
+    /// Creates an empty writer with room for at least `capacity` bytes.
+    #[inline]
+    pub fn with_capacity(py: Python<'py>, capacity: usize) -> PyResult<Self> {
+        #[cfg(not(Py_LIMITED_API))]
+        {
+            let size = capacity as ffi::Py_ssize_t;
+            let writer = NonNull::new(unsafe { PyBytesWriter_Create(size) })
+                .map(|writer| PyBytesWriter { python: py, writer })
+                .ok_or_else(|| PyErr::fetch(py))?;
+            // `PyBytesWriter_Create` also sets the length to `size`; shrink it back to zero so
+            // the reserved (uninitialized) bytes never become part of the written content.
+            if size > 0 && unsafe { PyBytesWriter_Grow(writer.writer.as_ptr(), -size) } < 0 {
+                return Err(PyErr::fetch(py));
+            }
+            Ok(writer)
+        }
+
+        #[cfg(Py_LIMITED_API)]
+        {
+            Ok(PyBytesWriter {
+                python: py,
+                buffer: Vec::with_capacity(capacity),
+            })
+        }
+    }
+
+    /// Returns the number of bytes the writer can hold without reallocating.
+    #[inline]
+    pub fn capacity(&self) -> usize {
+        #[cfg(not(Py_LIMITED_API))]
+        unsafe {
+            _PyBytesWriter_GetAllocated(self.writer.as_ptr()) as _
+        }
+
+        #[cfg(Py_LIMITED_API)]
+        {
+            self.buffer.capacity()
+        }
+    }
+
+    /// Returns the number of bytes written so far.
+    #[inline]
+    pub fn len(&self) -> usize {
+        #[cfg(not(Py_LIMITED_API))]
+        unsafe {
+            PyBytesWriter_GetSize(self.writer.as_ptr()) as _
+        }
+
+        #[cfg(Py_LIMITED_API)]
+        {
+            self.buffer.len()
+        }
+    }
+
+    /// Returns a raw pointer to the start of the writer's buffer.
+    #[inline]
+    #[cfg(not(Py_LIMITED_API))]
+    fn as_mut_ptr(&mut self) -> *mut u8 {
+        unsafe { PyBytesWriter_GetData(self.writer.as_ptr()) as _ }
+    }
+}
+
+impl<'py> From<PyBytesWriter<'py>> for Bound<'py, PyBytes> {
+    #[inline]
+    #[cfg(not(Py_LIMITED_API))]
+    fn from(value: PyBytesWriter<'py>) -> Self {
+        let py = value.python;
+        unsafe {
+            // `PyBytesWriter_Finish` frees the writer, so `Drop` must not run as well.
+            PyBytesWriter_Finish(ManuallyDrop::new(value).writer.as_ptr())
+                .assume_owned(py)
+                .cast_into_unchecked()
+        }
+    }
+
+    #[inline]
+    #[cfg(Py_LIMITED_API)]
+    fn from(writer: PyBytesWriter<'py>) -> Self {
+        PyBytes::new(writer.python, &writer.buffer)
+    }
+}
+
+impl<'py> IntoPyObject<'py> for PyBytesWriter<'py> {
+    type Target = PyBytes;
+    type Output = Bound<'py, PyBytes>;
+    type Error = Infallible;
+
+    #[inline]
+    fn into_pyobject(self, _py: Python<'py>) -> Result<Self::Output, Self::Error> {
+        Ok(self.into())
+    }
+}
+
+#[cfg(not(Py_LIMITED_API))]
+impl<'py> Drop for PyBytesWriter<'py> {
+    #[inline]
+    fn drop(&mut self) {
+        unsafe { PyBytesWriter_Discard(self.writer.as_ptr()) }
+    }
+}
+
+#[cfg(not(Py_LIMITED_API))]
+impl std::io::Write for PyBytesWriter<'_> {
+    #[inline]
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        let result = unsafe {
+            PyBytesWriter_WriteBytes(self.writer.as_ptr(), buf.as_ptr() as _, buf.len() as _)
+        };
+
+        if result < 0 {
+            Err(PyErr::fetch(self.python).into())
+        } else {
+            Ok(buf.len())
+        }
+    }
+
+    #[inline]
+    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
+        let len = bufs.iter().map(|b| b.len()).sum();
+        let mut pos = self.len();
+
+        if unsafe { PyBytesWriter_Grow(self.writer.as_ptr(), len as _) } < 0 {
+            return Err(PyErr::fetch(self.python).into());
+        }
+
+        for buf in bufs {
+            // SAFETY: We have ensured enough capacity above.
+            unsafe {
+                ptr::copy_nonoverlapping(buf.as_ptr(), self.as_mut_ptr().add(pos), buf.len())
+            };
+            pos += buf.len();
+        }
+        Ok(len)
+    }
+
+    #[inline]
+    fn flush(&mut self) -> std::io::Result<()> {
+        Ok(())
+    }
+
+    #[inline]
+    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+        self.write(buf)?;
+        Ok(())
+    }
+}
+
+#[cfg(Py_LIMITED_API)]
+impl std::io::Write for PyBytesWriter<'_> {
+    #[inline]
+    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
+        self.buffer.write(buf)
+    }
+
+    #[inline]
+    fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> std::io::Result<usize> {
+        self.buffer.write_vectored(bufs)
+    }
+
+    #[inline]
+    fn flush(&mut self) -> std::io::Result<()> {
+        self.buffer.flush()
+    }
+
+    #[inline]
+    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
+        self.buffer.write_all(buf)
+    }
+
+    #[inline]
+    fn write_fmt(&mut self, args: std::fmt::Arguments<'_>) -> std::io::Result<()> {
+        self.buffer.write_fmt(args)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::io::Write;
+
+    #[test]
+    fn test_io_write() {
+        Python::attach(|py| {
+            let buf = [1, 2, 3, 4];
+            let mut writer = PyBytesWriter::new(py).unwrap();
+            assert_eq!(writer.write(&buf).unwrap(), 4);
+            let bytes: Bound<'_, PyBytes> = writer.into();
+            assert_eq!(bytes, buf);
+        })
+    }
+
+    #[test]
+    fn test_io_write_vectored() {
+        Python::attach(|py| {
+            let bufs = [IoSlice::new(&[1, 2]), IoSlice::new(&[3, 4])];
+            let mut writer = PyBytesWriter::new(py).unwrap();
+            assert_eq!(writer.write_vectored(&bufs).unwrap(), 4);
+            let bytes: Bound<'_, PyBytes> = writer.into();
+            assert_eq!(bytes, [1, 2, 3, 4]);
+        })
+    }
+
+    #[test]
+    fn test_with_capacity() {
+        Python::attach(|py| {
+            let mut writer = PyBytesWriter::with_capacity(py, 512).unwrap();
+            assert_eq!(writer.len(), 0);
+            assert!(writer.capacity() >= 512);
+            assert_eq!(writer.write(b"ab").unwrap(), 2);
+            let bytes: Bound<'_, PyBytes> = writer.into();
+            assert_eq!(bytes, [b'a', b'b']);
+        })
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index ce70e11439a..261e390a0bb 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -478,6 +478,7 @@ pub mod inspect;
 
 // Putting the declaration of prelude at the end seems to help encourage rustc and rustdoc to prefer using
 // other paths to the same items. (e.g. `pyo3::types::PyAnyMethods` instead of `pyo3::prelude::PyAnyMethods`).
+mod byteswriter;
 pub mod prelude;
 
 /// Test readme and user guide
diff --git a/src/types/bytes.rs b/src/types/bytes.rs
index c6fa1d65329..9ce41a11806 100644
--- a/src/types/bytes.rs
+++ b/src/types/bytes.rs
@@ -195,6 +195,16 @@ impl PartialEq<&'_ [u8]> for Bound<'_, PyBytes> {
     }
 }
 
+/// Compares whether the Python bytes object is equal to the byte array.
+///
+/// In some cases Python equality might be more appropriate; see the note on [`PyBytes`].
+impl<const N: usize> PartialEq<[u8; N]> for Bound<'_, PyBytes> {
+    #[inline]
+    fn eq(&self, other: &[u8; N]) -> bool {
+        self.as_borrowed() == *other
+    }
+}
+
 /// Compares whether the Python bytes object is equal to the [u8].
 ///
 /// In some cases Python equality might be more appropriate; see the note on [`PyBytes`].
@@ -245,6 +255,16 @@ impl PartialEq<[u8]> for Borrowed<'_, '_, PyBytes> {
     }
 }
 
+/// Compares whether the Python bytes object is equal to the byte array.
+///
+/// In some cases Python equality might be more appropriate; see the note on [`PyBytes`].
+impl<const N: usize> PartialEq<[u8; N]> for Borrowed<'_, '_, PyBytes> {
+    #[inline]
+    fn eq(&self, other: &[u8; N]) -> bool {
+        self.as_bytes() == other
+    }
+}
+
 /// Compares whether the Python bytes object is equal to the [u8].
 ///
 /// In some cases Python equality might be more appropriate; see the note on [`PyBytes`].