From dd047097e75ce3ad1057d453ecd504c4ffde7149 Mon Sep 17 00:00:00 2001 From: Icxolu <10486322+Icxolu@users.noreply.github.com> Date: Thu, 21 Aug 2025 18:24:47 +0200 Subject: [PATCH] add `PyErr::set_traceback` --- src/err/err_state.rs | 45 +++++++++++++++++++++++++++-------- src/err/mod.rs | 5 ++++ src/impl_/extract_argument.rs | 1 + 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 8616979310f..12ee528c7ae 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -4,6 +4,8 @@ use std::{ thread::ThreadId, }; +#[cfg(not(Py_3_12))] +use crate::sync::MutexExt; use crate::{ exceptions::{PyBaseException, PyTypeError}, ffi, @@ -137,7 +139,7 @@ pub(crate) struct PyErrStateNormalized { ptype: Py, pub pvalue: Py, #[cfg(not(Py_3_12))] - ptraceback: Option>, + ptraceback: std::sync::Mutex>>, } impl PyErrStateNormalized { @@ -147,10 +149,10 @@ impl PyErrStateNormalized { ptype: pvalue.get_type().into(), #[cfg(not(Py_3_12))] ptraceback: unsafe { - Py::from_owned_ptr_or_opt( + Mutex::new(Py::from_owned_ptr_or_opt( pvalue.py(), ffi::PyException_GetTraceback(pvalue.as_ptr()), - ) + )) }, pvalue: pvalue.into(), } @@ -169,6 +171,8 @@ impl PyErrStateNormalized { #[cfg(not(Py_3_12))] pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option> { self.ptraceback + .lock_py_attached(py) + .unwrap() .as_ref() .map(|traceback| traceback.bind(py).clone()) } @@ -182,6 +186,21 @@ impl PyErrStateNormalized { } } + #[cfg(not(Py_3_12))] + pub(crate) fn set_ptraceback<'py>(&self, py: Python<'py>, tb: Option>) { + *self.ptraceback.lock_py_attached(py).unwrap() = tb.map(Bound::unbind); + } + + #[cfg(Py_3_12)] + pub(crate) fn set_ptraceback<'py>(&self, py: Python<'py>, tb: Option>) { + let tb = tb + .as_ref() + .map(Bound::as_ptr) + .unwrap_or_else(|| crate::types::PyNone::get(py).as_ptr()); + + unsafe { ffi::PyException_SetTraceback(self.pvalue.as_ptr(), tb) }; + } + pub(crate) fn take(py: Python<'_>) -> Option { #[cfg(Py_3_12)] { @@ -227,7 +246,7 @@ impl PyErrStateNormalized { ptype.map(|ptype| PyErrStateNormalized { ptype: ptype.unbind(), pvalue: pvalue.expect("normalized exception value missing").unbind(), - ptraceback: ptraceback.map(Bound::unbind), + ptraceback: std::sync::Mutex::new(ptraceback.map(Bound::unbind)), }) } } @@ -244,7 +263,7 @@ impl PyErrStateNormalized { pvalue: unsafe { Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing") }, - ptraceback: unsafe { Py::from_owned_ptr_or_opt(py, ptraceback) }, + ptraceback: unsafe { std::sync::Mutex::new(Py::from_owned_ptr_or_opt(py, ptraceback)) }, } } @@ -254,10 +273,13 @@ impl PyErrStateNormalized { ptype: self.ptype.clone_ref(py), pvalue: self.pvalue.clone_ref(py), #[cfg(not(Py_3_12))] - ptraceback: self - .ptraceback - .as_ref() - .map(|ptraceback| ptraceback.clone_ref(py)), + ptraceback: std::sync::Mutex::new( + self.ptraceback + .lock_py_attached(py) + .unwrap() + .as_ref() + .map(|ptraceback| ptraceback.clone_ref(py)), + ), } } } @@ -308,7 +330,10 @@ impl PyErrStateInner { }) => ( ptype.into_ptr(), pvalue.into_ptr(), - ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), + ptraceback + .into_inner() + .unwrap() + .map_or(std::ptr::null_mut(), Py::into_ptr), ), }; unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) } diff --git a/src/err/mod.rs b/src/err/mod.rs index 16265865283..0c49535eaa9 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -307,6 +307,11 @@ impl PyErr { self.normalized(py).ptraceback(py) } + /// Set the traceback associated with the exception, pass `None` to clear it. + pub fn set_traceback<'py>(&self, py: Python<'_>, tb: Option>) { + self.normalized(py).set_ptraceback(py, tb) + } + /// Gets whether an error is present in the Python interpreter's global state. #[inline] pub fn occurred(_: Python<'_>) -> bool { diff --git a/src/impl_/extract_argument.rs b/src/impl_/extract_argument.rs index 0bca69ce3af..eda13ce0d98 100644 --- a/src/impl_/extract_argument.rs +++ b/src/impl_/extract_argument.rs @@ -224,6 +224,7 @@ pub fn argument_extraction_error(py: Python<'_>, arg_name: &str, error: PyErr) - let remapped_error = PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py))); remapped_error.set_cause(py, error.cause(py)); + remapped_error.set_traceback(py, error.traceback(py)); remapped_error } else { error