diff --git a/pyo3-polars/Cargo.toml b/pyo3-polars/Cargo.toml index 3bf5491..bf6406f 100644 --- a/pyo3-polars/Cargo.toml +++ b/pyo3-polars/Cargo.toml @@ -11,7 +11,6 @@ description = "Expression plugins and PyO3 types for polars" [dependencies] ciborium = { version = "0.2.1", optional = true } -polars = { workspace = true, default-features = false } polars-core = { workspace = true, default-features = false } polars-ffi = { workspace = true, optional = true } polars-lazy = { workspace = true, optional = true } @@ -21,6 +20,12 @@ pyo3-polars-derive = { version = "0.5.0", path = "../pyo3-polars-derive", option serde = { version = "1", optional = true } serde-pickle = { version = "1", optional = true } thiserror = "1" +once_cell = "1" +itoa = "1.0.6" + +[dependencies.polars] +workspace = true +features = ["dtype-full"] [features] lazy = ["polars/serde-lazy", "polars-plan", "polars-lazy/serde", "ciborium"] diff --git a/pyo3-polars/src/gil_once_cell.rs b/pyo3-polars/src/gil_once_cell.rs new file mode 100644 index 0000000..7a62029 --- /dev/null +++ b/pyo3-polars/src/gil_once_cell.rs @@ -0,0 +1,34 @@ +use std::cell::UnsafeCell; + +use pyo3::{PyResult, Python}; + +// Adapted from PYO3 with the only change that +// we allow mutable access with when the GIL is held + +pub struct GILOnceCell(UnsafeCell>); + +// T: Send is needed for Sync because the thread which drops the GILOnceCell can be different +// to the thread which fills it. +unsafe impl Sync for GILOnceCell {} +unsafe impl Send for GILOnceCell {} + +impl GILOnceCell { + /// Create a `GILOnceCell` which does not yet contain a value. + pub const fn new() -> Self { + Self(UnsafeCell::new(None)) + } + + /// as long as we have the GIL we can mutate + /// this creates a context that checks that. + pub fn with_gil(&self, _py: Python<'_>, mut op: F) -> PyResult + where + F: FnMut(&mut T) -> PyResult, + { + // Safe because GIL is held, so no other thread can be writing to this cell concurrently. + let inner = unsafe { &mut *self.0.get() } + .as_mut() + .expect("not yet initialized"); + + op(inner) + } +} diff --git a/pyo3-polars/src/lib.rs b/pyo3-polars/src/lib.rs index 17f49c3..ef310ff 100644 --- a/pyo3-polars/src/lib.rs +++ b/pyo3-polars/src/lib.rs @@ -1,4 +1,4 @@ -//! This crate offers a [`PySeries`] and a [`PyDataFrame`] which are simple wrapper around `Series` and `DataFrame`. The +//! This crate offers [`PySeries`], [`PyDataFrame`] and [`PyAnyValue`] which are simple wrapper around `Series`, `DataFrame` and `AnyValue`. The //! advantage of these wrappers is that they can be converted to and from python as they implement `FromPyObject` and `IntoPy`. //! //! # Example @@ -47,13 +47,24 @@ pub mod error; #[cfg(feature = "derive")] pub mod export; mod ffi; +mod gil_once_cell; +mod py_modules; +mod utils; use crate::error::PyPolarsErr; use crate::ffi::to_py::to_py_array; +use crate::py_modules::SERIES; +use crate::utils::{ + abs_decimal_from_digits, any_values_to_dtype, convert_date, convert_datetime, decimal_to_digits, +}; use polars::export::arrow; use polars::prelude::*; +use py_modules::UTILS; use pyo3::ffi::Py_uintptr_t; +use pyo3::types::{PyBool, PyDict, PyFloat, PyList, PySequence, PyString, PyTuple, PyType}; +use pyo3::{intern, PyErr}; use pyo3::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject}; +use utils::struct_dict; #[cfg(feature = "lazy")] use {polars_lazy::frame::LazyFrame, polars_plan::logical_plan::LogicalPlan}; @@ -68,6 +79,11 @@ pub struct PySeries(pub Series); /// A wrapper around a [`DataFrame`] that can be converted to and from python with `pyo3`. pub struct PyDataFrame(pub DataFrame); +#[repr(transparent)] +#[derive(Debug, Clone)] +/// A wrapper around [`AnyValue`] that can be converted to and from python with `pyo3`. +pub struct PyAnyValue<'a>(pub AnyValue<'a>); + #[cfg(feature = "lazy")] #[repr(transparent)] #[derive(Clone)] @@ -92,6 +108,18 @@ impl From for Series { } } +impl<'a> From> for AnyValue<'a> { + fn from(value: PyAnyValue<'a>) -> Self { + value.0 + } +} + +impl<'a> From> for PyAnyValue<'a> { + fn from(value: AnyValue<'a>) -> Self { + PyAnyValue(value) + } +} + #[cfg(feature = "lazy")] impl From for LazyFrame { fn from(value: PyLazyFrame) -> Self { @@ -111,6 +139,12 @@ impl AsRef for PyDataFrame { } } +impl<'a> AsRef> for PyAnyValue<'a> { + fn as_ref(&self) -> &AnyValue<'a> { + &self.0 + } +} + #[cfg(feature = "lazy")] impl AsRef for PyLazyFrame { fn as_ref(&self) -> &LazyFrame { @@ -147,6 +181,250 @@ impl<'a> FromPyObject<'a> for PyDataFrame { } } +type TypeObjectPtr = usize; +type InitFn = fn(&PyAny) -> PyResult>; +pub(crate) static LUT: crate::gil_once_cell::GILOnceCell> = + crate::gil_once_cell::GILOnceCell::new(); + +impl<'s> FromPyObject<'s> for PyAnyValue<'s> { + fn extract(ob: &'s PyAny) -> PyResult { + // conversion functions + fn get_bool(ob: &PyAny) -> PyResult> { + Ok(AnyValue::Boolean(ob.extract::().unwrap()).into()) + } + + fn get_int(ob: &PyAny) -> PyResult> { + // can overflow + match ob.extract::() { + Ok(v) => Ok(AnyValue::Int64(v).into()), + Err(_) => Ok(AnyValue::UInt64(ob.extract::()?).into()), + } + } + + fn get_float(ob: &PyAny) -> PyResult> { + Ok(AnyValue::Float64(ob.extract::().unwrap()).into()) + } + + fn get_str(ob: &PyAny) -> PyResult> { + let value = ob.extract::<&str>().unwrap(); + Ok(AnyValue::String(value).into()) + } + + fn get_struct(ob: &PyAny) -> PyResult> { + let dict = ob.downcast::().unwrap(); + let len = dict.len(); + let mut keys = Vec::with_capacity(len); + let mut vals = Vec::with_capacity(len); + for (k, v) in dict.into_iter() { + let key = k.extract::<&str>()?; + let val = v.extract::()?.0; + let dtype = DataType::from(&val); + keys.push(Field::new(key, dtype)); + vals.push(val) + } + Ok(AnyValue::StructOwned(Box::new((vals, keys))).into()) + } + + fn get_list(ob: &PyAny) -> PyResult { + fn get_list_with_constructor(ob: &PyAny) -> PyResult { + // Use the dedicated constructor + // this constructor is able to go via dedicated type constructors + // so it can be much faster + Python::with_gil(|py| { + let s = SERIES.call1(py, (ob,))?; + get_series_el(s.as_ref(py)) + }) + } + + if ob.is_empty()? { + Ok(AnyValue::List(Series::new_empty("", &DataType::Null)).into()) + } else if ob.is_instance_of::() | ob.is_instance_of::() { + let list = ob.downcast::().unwrap(); + + let mut avs = Vec::with_capacity(25); + let mut iter = list.iter()?; + + for item in (&mut iter).take(25) { + avs.push(item?.extract::()?.0) + } + + let (dtype, n_types) = any_values_to_dtype(&avs).map_err(PyPolarsErr::from)?; + + // we only take this path if there is no question of the data-type + if dtype.is_primitive() && n_types == 1 { + get_list_with_constructor(ob) + } else { + // push the rest + avs.reserve(list.len()?); + for item in iter { + avs.push(item?.extract::()?.0) + } + + let s = Series::from_any_values_and_dtype("", &avs, &dtype, true) + .map_err(PyPolarsErr::from)?; + Ok(AnyValue::List(s).into()) + } + } else { + // range will take this branch + get_list_with_constructor(ob) + } + } + + fn get_series_el(ob: &PyAny) -> PyResult> { + let py_pyseries = ob.getattr(intern!(ob.py(), "_s")).unwrap(); + let series = py_pyseries.extract::().unwrap().0; + Ok(AnyValue::List(series).into()) + } + + fn get_bin(ob: &PyAny) -> PyResult { + let value = ob.extract::<&[u8]>().unwrap(); + Ok(AnyValue::Binary(value).into()) + } + + fn get_null(_ob: &PyAny) -> PyResult { + Ok(AnyValue::Null.into()) + } + + fn get_timedelta(ob: &PyAny) -> PyResult { + Python::with_gil(|py| { + let td = UTILS + .as_ref(py) + .getattr(intern!(py, "_timedelta_to_pl_timedelta")) + .unwrap() + .call1((ob, intern!(py, "us"))) + .unwrap(); + let v = td.extract::().unwrap(); + Ok(AnyValue::Duration(v, TimeUnit::Microseconds).into()) + }) + } + + fn get_time(ob: &PyAny) -> PyResult { + Python::with_gil(|py| { + let time = UTILS + .as_ref(py) + .getattr(intern!(py, "_time_to_pl_time")) + .unwrap() + .call1((ob,)) + .unwrap(); + let v = time.extract::().unwrap(); + Ok(AnyValue::Time(v).into()) + }) + } + + fn get_decimal(ob: &PyAny) -> PyResult { + let (sign, digits, exp): (i8, Vec, i32) = ob + .call_method0(intern!(ob.py(), "as_tuple")) + .unwrap() + .extract() + .unwrap(); + // note: using Vec is not the most efficient thing here (input is a tuple) + let (mut v, scale) = abs_decimal_from_digits(digits, exp).ok_or_else(|| { + PyErr::from(PyPolarsErr::Other( + "Decimal is too large to fit in Decimal128".into(), + )) + })?; + if sign > 0 { + v = -v; // won't overflow since -i128::MAX > i128::MIN + } + Ok(AnyValue::Decimal(v, scale).into()) + } + + fn get_object(_ob: &PyAny) -> PyResult { + // TODO: need help here + // #[cfg(feature = "object")] + // { + // // this is slow, but hey don't use objects + // let v = &ObjectValue { inner: ob.into() }; + // Ok(AnyValue::ObjectOwned(OwnedObject(v.to_boxed())).into()) + // } + #[cfg(not(feature = "object"))] + { + panic!("activate object") + } + } + + // TYPE key + let type_object_ptr = PyType::as_type_ptr(ob.get_type()) as usize; + + Python::with_gil(|py| { + LUT.with_gil(py, |lut| { + // get the conversion function + let convert_fn = lut.entry(type_object_ptr).or_insert_with( + // This only runs if type is not in LUT + || { + if ob.is_instance_of::() { + get_bool + // TODO: this heap allocs on failure + } else if ob.extract::().is_ok() || ob.extract::().is_ok() { + get_int + } else if ob.is_instance_of::() { + get_float + } else if ob.is_instance_of::() { + get_str + } else if ob.is_instance_of::() { + get_struct + } else if ob.is_instance_of::() || ob.is_instance_of::() { + get_list + } else if ob.hasattr(intern!(py, "_s")).unwrap() { + get_series_el + } + // TODO: this heap allocs on failure + else if ob.extract::<&'s [u8]>().is_ok() { + get_bin + } else if ob.is_none() { + get_null + } else { + let type_name = ob.get_type().name().unwrap(); + match type_name { + "datetime" => convert_datetime, + "date" => convert_date, + "timedelta" => get_timedelta, + "time" => get_time, + "Decimal" => get_decimal, + "range" => get_list, + _ => { + // special branch for np.float as this fails isinstance float + if ob.extract::().is_ok() { + return get_float; + } + + // Can't use pyo3::types::PyDateTime with abi3-py37 feature, + // so need this workaround instead of `isinstance(ob, datetime)`. + let bases = ob + .get_type() + .getattr(intern!(py, "__bases__")) + .unwrap() + .iter() + .unwrap(); + for base in bases { + let parent_type = + base.unwrap().str().unwrap().to_str().unwrap(); + match parent_type { + "" => { + // `datetime.datetime` is a subclass of `datetime.date`, + // so need to check `datetime.datetime` first + return convert_datetime; + } + "" => { + return convert_date; + } + _ => (), + } + } + + get_object + } + } + } + }, + ); + + convert_fn(ob) + }) + }) + } +} + #[cfg(feature = "lazy")] impl<'a> FromPyObject<'a> for PyLazyFrame { fn extract(ob: &'a PyAny) -> PyResult { @@ -234,6 +512,100 @@ impl IntoPy for PyDataFrame { } } +impl IntoPy for PyAnyValue<'_> { + fn into_py(self, py: Python) -> PyObject { + let utils = UTILS.as_ref(py); + match self.0 { + AnyValue::UInt8(v) => v.into_py(py), + AnyValue::UInt16(v) => v.into_py(py), + AnyValue::UInt32(v) => v.into_py(py), + AnyValue::UInt64(v) => v.into_py(py), + AnyValue::Int8(v) => v.into_py(py), + AnyValue::Int16(v) => v.into_py(py), + AnyValue::Int32(v) => v.into_py(py), + AnyValue::Int64(v) => v.into_py(py), + AnyValue::Float32(v) => v.into_py(py), + AnyValue::Float64(v) => v.into_py(py), + AnyValue::Null => py.None(), + AnyValue::Boolean(v) => v.into_py(py), + AnyValue::String(v) => v.into_py(py), + AnyValue::StringOwned(v) => v.into_py(py), + AnyValue::Categorical(idx, rev, arr) | AnyValue::Enum(idx, rev, arr) => { + let s = if arr.is_null() { + rev.get(idx) + } else { + unsafe { arr.deref_unchecked().value(idx as usize) } + }; + s.into_py(py) + } + AnyValue::Date(v) => { + let convert = utils.getattr(intern!(py, "_to_python_date")).unwrap(); + convert.call1((v,)).unwrap().into_py(py) + } + AnyValue::Datetime(v, time_unit, time_zone) => { + let convert = utils.getattr(intern!(py, "_to_python_datetime")).unwrap(); + let time_unit = time_unit.to_ascii(); + convert + .call1((v, time_unit, time_zone.as_ref().map(|s| s.as_str()))) + .unwrap() + .into_py(py) + } + AnyValue::Duration(v, time_unit) => { + let convert = utils.getattr(intern!(py, "_to_python_timedelta")).unwrap(); + let time_unit = time_unit.to_ascii(); + convert.call1((v, time_unit)).unwrap().into_py(py) + } + AnyValue::Time(v) => { + let convert = utils.getattr(intern!(py, "_to_python_time")).unwrap(); + convert.call1((v,)).unwrap().into_py(py) + } + // TODO: need help here + AnyValue::Array(_v, _) | AnyValue::List(_v) => { + todo!(); + // PySeries(v).to_list() + } + ref av @ AnyValue::Struct(_, _, flds) => struct_dict(py, av._iter_struct_av(), flds), + AnyValue::StructOwned(payload) => struct_dict(py, payload.0.into_iter(), &payload.1), + // TODO: Also need help here + // #[cfg(feature = "object")] + // AnyValue::Object(v) => { + // let object = v.as_any().downcast_ref::().unwrap(); + // object.inner.clone() + // } + // #[cfg(feature = "object")] + // AnyValue::ObjectOwned(v) => { + // let object = v.0.as_any().downcast_ref::().unwrap(); + // object.inner.clone() + // } + AnyValue::Binary(v) => v.into_py(py), + AnyValue::BinaryOwned(v) => v.into_py(py), + AnyValue::Decimal(v, scale) => { + let convert = utils.getattr(intern!(py, "_to_python_decimal")).unwrap(); + const N: usize = 3; + let mut buf = [0_u128; N]; + let n_digits = decimal_to_digits(v.abs(), &mut buf); + let buf = unsafe { + std::slice::from_raw_parts( + buf.as_slice().as_ptr() as *const u8, + N * std::mem::size_of::(), + ) + }; + let digits = PyTuple::new(py, buf.iter().take(n_digits)); + convert + .call1((v.is_negative() as u8, digits, n_digits, -(scale as i32))) + .unwrap() + .into_py(py) + } + } + } +} + +impl ToPyObject for PyAnyValue<'_> { + fn to_object(&self, py: Python) -> PyObject { + self.clone().into_py(py) + } +} + #[cfg(feature = "lazy")] impl IntoPy for PyLazyFrame { fn into_py(self, py: Python<'_>) -> PyObject { diff --git a/pyo3-polars/src/py_modules.rs b/pyo3-polars/src/py_modules.rs new file mode 100644 index 0000000..6c7dbd2 --- /dev/null +++ b/pyo3-polars/src/py_modules.rs @@ -0,0 +1,11 @@ +use once_cell::sync::Lazy; +use pyo3::prelude::*; + +pub(crate) static POLARS: Lazy = + Lazy::new(|| Python::with_gil(|py| PyModule::import(py, "polars").unwrap().to_object(py))); + +pub(crate) static UTILS: Lazy = + Lazy::new(|| Python::with_gil(|py| POLARS.getattr(py, "utils").unwrap())); + +pub(crate) static SERIES: Lazy = + Lazy::new(|| Python::with_gil(|py| POLARS.getattr(py, "Series").unwrap())); diff --git a/pyo3-polars/src/utils.rs b/pyo3-polars/src/utils.rs new file mode 100644 index 0000000..20b52bb --- /dev/null +++ b/pyo3-polars/src/utils.rs @@ -0,0 +1,132 @@ +use crate::py_modules::UTILS; +use crate::PyAnyValue; +use polars::prelude::*; +use polars_core::utils::try_get_supertype; +use pyo3::intern; +use pyo3::types::PyDict; +use pyo3::{IntoPy, PyAny, PyObject, PyResult, Python}; + +pub(crate) fn any_values_to_dtype(column: &[AnyValue]) -> PolarsResult<(DataType, usize)> { + // we need an index-map as the order of dtypes influences how the + // struct fields are constructed. + let mut types_set = PlIndexSet::new(); + for val in column.iter() { + types_set.insert(val.into()); + } + let n_types = types_set.len(); + Ok((types_set_to_dtype(types_set)?, n_types)) +} + +fn types_set_to_dtype(types_set: PlIndexSet) -> PolarsResult { + types_set + .into_iter() + .map(Ok) + .reduce(|a, b| try_get_supertype(&a?, &b?)) + .unwrap() +} + +pub(crate) fn abs_decimal_from_digits( + digits: impl IntoIterator, + exp: i32, +) -> Option<(i128, usize)> { + const MAX_ABS_DEC: i128 = 10_i128.pow(38) - 1; + let mut v = 0_i128; + for (i, d) in digits.into_iter().map(i128::from).enumerate() { + if i < 38 { + v = v * 10 + d; + } else { + v = v.checked_mul(10).and_then(|v| v.checked_add(d))?; + } + } + // we only support non-negative scale (=> non-positive exponent) + let scale = if exp > 0 { + // the decimal may be in a non-canonical representation, try to fix it first + v = 10_i128 + .checked_pow(exp as u32) + .and_then(|factor| v.checked_mul(factor))?; + 0 + } else { + (-exp) as usize + }; + // TODO: do we care for checking if it fits in MAX_ABS_DEC? (if we set precision to None anyway?) + (v <= MAX_ABS_DEC).then_some((v, scale)) +} + +pub(crate) fn convert_date(ob: &PyAny) -> PyResult { + Python::with_gil(|py| { + let date = UTILS + .as_ref(py) + .getattr(intern!(py, "_date_to_pl_date")) + .unwrap() + .call1((ob,)) + .unwrap(); + let v = date.extract::().unwrap(); + Ok(AnyValue::Date(v).into()) + }) +} +pub(crate) fn convert_datetime(ob: &PyAny) -> PyResult { + Python::with_gil(|py| { + // windows + #[cfg(target_arch = "windows")] + let (seconds, microseconds) = { + let convert = UTILS + .getattr(py, intern!(py, "_datetime_for_any_value_windows")) + .unwrap(); + let out = convert.call1(py, (ob,)).unwrap(); + let out: (i64, i64) = out.extract(py).unwrap(); + out + }; + // unix + #[cfg(not(target_arch = "windows"))] + let (seconds, microseconds) = { + let convert = UTILS + .getattr(py, intern!(py, "_datetime_for_any_value")) + .unwrap(); + let out = convert.call1(py, (ob,)).unwrap(); + let out: (i64, i64) = out.extract(py).unwrap(); + out + }; + + // s to us + let mut v = seconds * 1_000_000; + v += microseconds; + + // choose "us" as that is python's default unit + Ok(AnyValue::Datetime(v, TimeUnit::Microseconds, &None).into()) + }) +} + +pub(crate) fn struct_dict<'a>( + py: Python, + vals: impl Iterator>, + flds: &[Field], +) -> PyObject { + let dict = PyDict::new(py); + for (fld, val) in flds.iter().zip(vals) { + dict.set_item(fld.name().as_str(), PyAnyValue(val)).unwrap() + } + dict.into_py(py) +} + +// accept u128 array to ensure alignment is correct +pub(crate) fn decimal_to_digits(v: i128, buf: &mut [u128; 3]) -> usize { + const ZEROS: i128 = 0x3030_3030_3030_3030_3030_3030_3030_3030; + // safety: transmute is safe as there are 48 bytes in 3 128bit ints + // and the minimal alignment of u8 fits u16 + let buf = unsafe { std::mem::transmute::<&mut [u128; 3], &mut [u8; 48]>(buf) }; + let mut buffer = itoa::Buffer::new(); + let value = buffer.format(v); + let len = value.len(); + for (dst, src) in buf.iter_mut().zip(value.as_bytes().iter()) { + *dst = *src + } + + let ptr = buf.as_mut_ptr() as *mut i128; + unsafe { + // this is safe because we know that the buffer is exactly 48 bytes long + *ptr -= ZEROS; + *ptr.add(1) -= ZEROS; + *ptr.add(2) -= ZEROS; + } + len +}