Skip to content

Commit

Permalink
[BUG] Fix Resource Request Serialization and factor our Serialize Obj…
Browse files Browse the repository at this point in the history
…ect as bincode (#2707)
  • Loading branch information
samster25 authored Aug 22, 2024
1 parent 9cd2151 commit 7f04f36
Show file tree
Hide file tree
Showing 27 changed files with 109 additions and 111 deletions.
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/common/py-serde/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[dependencies]
bincode = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}

Expand Down
3 changes: 2 additions & 1 deletion src/common/py-serde/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#[cfg(feature = "python")]
mod python;

#[cfg(feature = "python")]
pub use crate::{python::deserialize_py_object, python::serialize_py_object};

pub use bincode;
39 changes: 38 additions & 1 deletion src/common/py-serde/src/python.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
use pyo3::{PyObject, Python, ToPyObject};
#[cfg(feature = "python")]
pub use pyo3::PyObject;
#[cfg(feature = "python")]
use pyo3::{Python, ToPyObject};

use serde::{de::Error as DeError, de::Visitor, ser::Error as SerError, Deserializer, Serializer};
use std::fmt;
#[cfg(feature = "python")]

pub fn serialize_py_object<S>(obj: &PyObject, s: S) -> Result<S::Ok, S::Error>
where
Expand All @@ -15,8 +20,10 @@ where
})?;
s.serialize_bytes(bytes.as_slice())
}
#[cfg(feature = "python")]

struct PyObjectVisitor;
#[cfg(feature = "python")]

impl<'de> Visitor<'de> for PyObjectVisitor {
type Value = PyObject;
Expand Down Expand Up @@ -57,3 +64,33 @@ where
{
d.deserialize_bytes(PyObjectVisitor)
}

#[macro_export]
macro_rules! impl_bincode_py_state_serialization {
($ty:ty) => {
#[cfg(feature = "python")]
#[pymethods]
impl $ty {
pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, PyObject)> {
use pyo3::types::PyBytes;
use pyo3::PyTypeInfo;
use pyo3::ToPyObject;
Ok((
Self::type_object(py)
.getattr("_from_serialized")?
.to_object(py),
(PyBytes::new(py, &$crate::bincode::serialize(&self).unwrap()).to_object(py),)
.to_object(py),
))
}

#[staticmethod]
pub fn _from_serialized(py: Python, serialized: PyObject) -> PyResult<Self> {
use pyo3::types::PyBytes;
serialized
.extract::<&PyBytes>(py)
.map(|s| $crate::bincode::deserialize(s.as_bytes()).unwrap())
}
}
};
}
3 changes: 2 additions & 1 deletion src/common/resource-request/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
[dependencies]
common-hashable-float-wrapper = {path = "../hashable-float-wrapper"}
common-py-serde = {path = "../py-serde"}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}

[features]
python = ["dep:pyo3"]
python = ["dep:pyo3", "common-py-serde/python"]

[package]
edition = {workspace = true}
Expand Down
5 changes: 4 additions & 1 deletion src/common/resource-request/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use common_hashable_float_wrapper::FloatWrapper;
use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyModule, PyResult, Python};
use pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyModule, PyObject, PyResult, Python};

use std::hash::{Hash, Hasher};
use std::ops::Add;

Expand Down Expand Up @@ -219,6 +221,7 @@ impl ResourceRequest {
Ok(format!("{:?}", self))
}
}
impl_bincode_py_state_serialization!(ResourceRequest);

#[cfg(feature = "python")]
pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
Expand Down
3 changes: 2 additions & 1 deletion src/daft-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ common-daft-config = {path = "../common/daft-config", default-features = false}
common-display = {path = "../common/display", default-features = false}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-py-serde = {path = "../common/py-serde", default-features = false}
daft-minhash = {path = "../daft-minhash", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
fastrand = "2.1.0"
Expand Down Expand Up @@ -62,7 +63,7 @@ features = ["xxh3", "const_xxh3"]
version = "0.8.5"

[features]
python = ["dep:pyo3", "dep:numpy", "common-error/python"]
python = ["dep:pyo3", "dep:numpy", "common-error/python", "common-py-serde/python"]

[package]
edition = {workspace = true}
Expand Down
5 changes: 2 additions & 3 deletions src/daft-core/src/count_mode.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo};
use pyo3::{exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter, Result};
use std::str::FromStr;

use crate::impl_bincode_py_state_serialization;

use common_error::{DaftError, DaftResult};

/// Supported count modes for Daft's count aggregation.
Expand Down
9 changes: 2 additions & 7 deletions src/daft-core/src/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ use std::{
str::FromStr,
};

use crate::impl_bincode_py_state_serialization;
use common_error::{DaftError, DaftResult};
use common_py_serde::impl_bincode_py_state_serialization;
#[cfg(feature = "python")]
use pyo3::{
exceptions::PyValueError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
};
use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyObject, PyResult, Python};

use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -41,7 +38,6 @@ impl JoinType {
Ok(self.to_string())
}
}

impl_bincode_py_state_serialization!(JoinType);

impl JoinType {
Expand Down Expand Up @@ -106,7 +102,6 @@ impl JoinStrategy {
Ok(self.to_string())
}
}

impl_bincode_py_state_serialization!(JoinStrategy);

impl JoinStrategy {
Expand Down
7 changes: 4 additions & 3 deletions src/daft-core/src/python/datatype.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::{
datatypes::{DataType, Field, ImageMode, TimeUnit},
ffi, impl_bincode_py_state_serialization,
ffi,
};

use common_py_serde::impl_bincode_py_state_serialization;
use pyo3::{
class::basic::CompareOp,
exceptions::PyValueError,
prelude::*,
types::{PyBytes, PyDict, PyString},
PyTypeInfo,
types::{PyDict, PyString},
};
use serde::{Deserialize, Serialize};

Expand Down
4 changes: 2 additions & 2 deletions src/daft-core/src/python/field.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use pyo3::{prelude::*, types::PyBytes, PyTypeInfo};
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};

use super::datatype::PyDataType;
use crate::datatypes;
use crate::impl_bincode_py_state_serialization;
use common_py_serde::impl_bincode_py_state_serialization;

#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
4 changes: 1 addition & 3 deletions src/daft-core/src/python/schema.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::PyTypeInfo;

use serde::{Deserialize, Serialize};

use super::datatype::PyDataType;
use super::field::PyField;
use crate::datatypes;
use crate::ffi::field_to_py;
use crate::impl_bincode_py_state_serialization;
use crate::schema;
use common_py_serde::impl_bincode_py_state_serialization;

#[pyclass(module = "daft.daft")]
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down
31 changes: 0 additions & 31 deletions src/daft-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ pub mod display_table;
pub mod dyn_compare;
pub mod supertype;

pub use bincode;

#[macro_export]
macro_rules! impl_binary_trait_by_reference {
($ty:ty, $trait:ident, $fname:ident) => {
Expand All @@ -16,32 +14,3 @@ macro_rules! impl_binary_trait_by_reference {
}
};
}

#[macro_export]
macro_rules! impl_bincode_py_state_serialization {
($ty:ty) => {
#[cfg(feature = "python")]
#[pymethods]
impl $ty {
pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, PyObject)> {
Ok((
Self::type_object(py)
.getattr("_from_serialized")?
.to_object(py),
(
PyBytes::new(py, &$crate::utils::bincode::serialize(&self).unwrap())
.to_object(py),
)
.to_object(py),
))
}

#[staticmethod]
pub fn _from_serialized(py: Python, serialized: PyObject) -> PyResult<Self> {
serialized
.extract::<&PyBytes>(py)
.map(|s| $crate::utils::bincode::deserialize(s.as_bytes()).unwrap())
}
}
};
}
3 changes: 2 additions & 1 deletion src/daft-csv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ async-compat = {workspace = true}
async-compression = {workspace = true}
async-stream = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
csv-async = "1.3.0"
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
Expand All @@ -24,7 +25,7 @@ url = {workspace = true}
rstest = {workspace = true}

[features]
python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-dsl/python"]
python = ["dep:pyo3", "common-error/python", "common-py-serde/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-dsl/python"]

[package]
edition = {workspace = true}
Expand Down
10 changes: 3 additions & 7 deletions src/daft-csv/src/options.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
use daft_core::{impl_bincode_py_state_serialization, schema::SchemaRef};
use common_py_serde::impl_bincode_py_state_serialization;
use daft_core::schema::SchemaRef;
use daft_dsl::ExprRef;
use serde::{Deserialize, Serialize};
#[cfg(feature = "python")]
use {
daft_core::python::schema::PySchema,
daft_dsl::python::PyExpr,
pyo3::{
pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
},
pyo3::{pyclass, pyclass::CompareOp, pymethods, PyObject, PyResult, Python},
};

/// Options for converting CSV data to Daft data.
Expand Down Expand Up @@ -148,7 +146,6 @@ impl CsvConvertOptions {
Ok(format!("{:?}", self))
}
}

impl_bincode_py_state_serialization!(CsvConvertOptions);

/// Options for parsing CSV files.
Expand Down Expand Up @@ -374,5 +371,4 @@ impl CsvReadOptions {
Ok(format!("{:?}", self))
}
}

impl_bincode_py_state_serialization!(CsvReadOptions);
3 changes: 1 addition & 2 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::hash::{Hash, Hasher};
use std::sync::Arc;

use common_error::DaftError;
use common_py_serde::impl_bincode_py_state_serialization;
use common_resource_request::ResourceRequest;
use daft_core::array::ops::Utf8NormalizeOptions;
use daft_core::python::datatype::PyTimeUnit;
Expand All @@ -14,7 +15,6 @@ use crate::{functions, Expr, ExprRef, LiteralValue};
use daft_core::{
count_mode::CountMode,
datatypes::{ImageFormat, ImageMode},
impl_bincode_py_state_serialization,
python::{datatype::PyDataType, field::PyField, schema::PySchema},
};

Expand All @@ -23,7 +23,6 @@ use pyo3::{
prelude::*,
pyclass::CompareOp,
types::{PyBool, PyBytes, PyFloat, PyInt, PyString},
PyTypeInfo,
};

#[pyfunction]
Expand Down
2 changes: 2 additions & 0 deletions src/daft-json/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
arrow2 = {workspace = true, features = ["io_json"]}
chrono = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-decoding = {path = "../daft-decoding"}
Expand Down Expand Up @@ -30,6 +31,7 @@ rstest = {workspace = true}
python = [
"dep:pyo3",
"common-error/python",
"common-py-serde/python",
"daft-core/python",
"daft-io/python",
"daft-table/python",
Expand Down
Loading

0 comments on commit 7f04f36

Please sign in to comment.