Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Use Daft Pickle instead of Ray Pickle and use bincode for serializing #2693

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 3 additions & 6 deletions src/common/py-serde/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ where
S: Serializer,
{
let bytes = Python::with_gil(|py| {
py.import(pyo3::intern!(py, "ray.cloudpickle"))
.or_else(|_| py.import(pyo3::intern!(py, "pickle")))
py.import(pyo3::intern!(py, "daft.pickle"))
.and_then(|m| m.getattr(pyo3::intern!(py, "dumps")))
.and_then(|f| f.call1((obj,)))
.and_then(|b| b.extract::<Vec<u8>>())
Expand All @@ -31,8 +30,7 @@ impl<'de> Visitor<'de> for PyObjectVisitor {
E: DeError,
{
Python::with_gil(|py| {
py.import(pyo3::intern!(py, "ray.cloudpickle"))
.or_else(|_| py.import(pyo3::intern!(py, "pickle")))
py.import(pyo3::intern!(py, "daft.pickle"))
.and_then(|m| m.getattr(pyo3::intern!(py, "loads")))
.and_then(|f| Ok(f.call1((v,))?.to_object(py)))
.map_err(|e| DeError::custom(e.to_string()))
Expand All @@ -44,8 +42,7 @@ impl<'de> Visitor<'de> for PyObjectVisitor {
E: DeError,
{
Python::with_gil(|py| {
py.import(pyo3::intern!(py, "ray.cloudpickle"))
.or_else(|_| py.import(pyo3::intern!(py, "pickle")))
py.import(pyo3::intern!(py, "daft.pickle"))
.and_then(|m| m.getattr(pyo3::intern!(py, "loads")))
.and_then(|f| Ok(f.call1((v,))?.to_object(py)))
.map_err(|e| DeError::custom(e.to_string()))
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[dependencies]
bincode = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-py-serde = {path = "../common/py-serde", default-features = false}
Expand All @@ -11,7 +12,6 @@ itertools = {workspace = true}
log = {workspace = true}
pyo3 = {workspace = true, optional = true}
serde = {workspace = true}
serde_json = {workspace = true}
typetag = "0.2.16"

[features]
Expand Down
14 changes: 6 additions & 8 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#[cfg(feature = "python")]
mod pyobj_serde;
mod udf;

use std::sync::Arc;
Expand Down Expand Up @@ -35,7 +33,7 @@ impl PythonUDF {
pub struct StatelessPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
partial_func: pyobj_serde::PyObjectWrapper,
partial_func: crate::pyobj_serde::PyObjectWrapper,
num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
Expand All @@ -46,12 +44,12 @@ pub struct StatelessPythonUDF {
pub struct StatefulPythonUDF {
pub name: Arc<String>,
#[cfg(feature = "python")]
pub stateful_partial_func: pyobj_serde::PyObjectWrapper,
pub stateful_partial_func: crate::pyobj_serde::PyObjectWrapper,
pub num_expressions: usize,
pub return_dtype: DataType,
pub resource_request: Option<ResourceRequest>,
#[cfg(feature = "python")]
pub init_args: Option<pyobj_serde::PyObjectWrapper>,
pub init_args: Option<crate::pyobj_serde::PyObjectWrapper>,
pub batch_size: Option<usize>,
pub concurrency: Option<usize>,
}
Expand All @@ -68,7 +66,7 @@ pub fn stateless_udf(
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateless(StatelessPythonUDF {
name: name.to_string().into(),
partial_func: pyobj_serde::PyObjectWrapper(py_partial_stateless_udf),
partial_func: crate::pyobj_serde::PyObjectWrapper(py_partial_stateless_udf),
num_expressions: expressions.len(),
return_dtype,
resource_request,
Expand Down Expand Up @@ -113,11 +111,11 @@ pub fn stateful_udf(
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
name: name.to_string().into(),
stateful_partial_func: pyobj_serde::PyObjectWrapper(py_stateful_partial_func),
stateful_partial_func: crate::pyobj_serde::PyObjectWrapper(py_stateful_partial_func),
num_expressions: expressions.len(),
return_dtype,
resource_request,
init_args: init_args.map(pyobj_serde::PyObjectWrapper),
init_args: init_args.map(crate::pyobj_serde::PyObjectWrapper),
batch_size,
concurrency,
})),
Expand Down
2 changes: 1 addition & 1 deletion src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub mod join;
mod lit;
pub mod optimization;
#[cfg(feature = "python")]
mod pyobject;
mod pyobj_serde;
#[cfg(feature = "python")]
pub mod python;
mod resolve_expr;
Expand Down
16 changes: 6 additions & 10 deletions src/daft-dsl/src/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use std::{
};

#[cfg(feature = "python")]
use crate::pyobject::DaftPyObject;
use crate::pyobj_serde::PyObjectWrapper;

/// Stores a literal value for queries and computations.
/// We only need to support the limited types below since those are the types that we would get from python.
Expand Down Expand Up @@ -71,7 +71,7 @@ pub enum LiteralValue {
Series(Series),
/// Python object.
#[cfg(feature = "python")]
Python(DaftPyObject),
Python(PyObjectWrapper),
}

impl Eq for LiteralValue {}
Expand Down Expand Up @@ -144,12 +144,8 @@ impl Display for LiteralValue {
#[cfg(feature = "python")]
Python(pyobj) => write!(f, "PyObject({})", {
use pyo3::prelude::*;
Python::with_gil(|py| {
pyobj
.pyobject
.call_method0(py, pyo3::intern!(py, "__str__"))
})
.unwrap()
Python::with_gil(|py| pyobj.0.call_method0(py, pyo3::intern!(py, "__str__")))
.unwrap()
}),
}
}
Expand Down Expand Up @@ -212,7 +208,7 @@ impl LiteralValue {
}
Series(series) => series.clone().rename("literal"),
#[cfg(feature = "python")]
Python(val) => PythonArray::from(("literal", vec![val.pyobject.clone()])).into_series(),
Python(val) => PythonArray::from(("literal", vec![val.0.clone()])).into_series(),
};
result
}
Expand Down Expand Up @@ -356,7 +352,7 @@ impl Literal for Series {
#[cfg(feature = "python")]
impl Literal for pyo3::PyObject {
fn lit(self) -> ExprRef {
Expr::Literal(LiteralValue::Python(DaftPyObject { pyobject: self })).into()
Expr::Literal(LiteralValue::Python(PyObjectWrapper(self))).into()
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::hash::{Hash, Hasher};
use std::{
hash::{Hash, Hasher},
io::Write,
};

use common_py_serde::{deserialize_py_object, serialize_py_object};
use pyo3::{PyObject, Python};
Expand Down Expand Up @@ -29,7 +32,25 @@ impl Hash for PyObjectWrapper {
// If Python object is hashable, hash the Python-side hash.
Ok(py_obj_hash) => py_obj_hash.hash(state),
// Fall back to hashing the pickled Python object.
Err(_) => serde_json::to_vec(self).unwrap().hash(state),
Err(_) => {
let hasher = HashWriter { state };
bincode::serialize_into(hasher, self)
.expect("Pickling error occurred when computing hash of Pyobject")
}
}
}
}

/// Adapter exposing a borrowed [`Hasher`] through the [`Write`] interface.
///
/// Bytes written to it are hashed into the wrapped hasher rather than stored,
/// so anything that can serialize into a writer can be hashed without first
/// collecting its output into an intermediate buffer.
struct HashWriter<'a, H: Hasher> {
    /// Hasher that receives every chunk written through this adapter.
    state: &'a mut H,
}

impl<H: Hasher> Write for HashWriter<'_, H> {
    /// Hashes `buf` into the wrapped hasher and reports the whole slice as consumed.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let consumed = buf.len();
        // Route through `Hash for [u8]` (not `Hasher::write`) so the bytes are
        // fed to the hasher exactly the way a plain slice hash would feed them.
        Hash::hash(buf, &mut *self.state);
        Ok(consumed)
    }

    /// Nothing is buffered, so there is never anything to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
89 changes: 0 additions & 89 deletions src/daft-dsl/src/pyobject.rs

This file was deleted.

Loading