diff --git a/Cargo.lock b/Cargo.lock index 00dcccea8c..e7902a5061 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1290,6 +1290,7 @@ dependencies = [ name = "common-py-serde" version = "0.3.0-dev0" dependencies = [ + "bincode", "pyo3", "serde", ] @@ -1299,6 +1300,7 @@ name = "common-resource-request" version = "0.3.0-dev0" dependencies = [ "common-hashable-float-wrapper", + "common-py-serde", "pyo3", "serde", ] @@ -1640,6 +1642,7 @@ dependencies = [ "common-display", "common-error", "common-hashable-float-wrapper", + "common-py-serde", "daft-minhash", "daft-sketch", "fastrand 2.1.0", @@ -1677,6 +1680,7 @@ dependencies = [ "async-compression", "async-stream", "common-error", + "common-py-serde", "csv-async", "daft-compression", "daft-core", @@ -1799,6 +1803,7 @@ dependencies = [ "arrow2", "chrono", "common-error", + "common-py-serde", "daft-compression", "daft-core", "daft-decoding", @@ -1989,6 +1994,7 @@ dependencies = [ "common-display", "common-error", "common-io-config", + "common-py-serde", "daft-core", "daft-dsl", "daft-plan", diff --git a/src/common/py-serde/Cargo.toml b/src/common/py-serde/Cargo.toml index bb00e8c9b9..e799b0a574 100644 --- a/src/common/py-serde/Cargo.toml +++ b/src/common/py-serde/Cargo.toml @@ -1,4 +1,5 @@ [dependencies] +bincode = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true} diff --git a/src/common/py-serde/src/lib.rs b/src/common/py-serde/src/lib.rs index bb14399064..af7c0cf7ed 100644 --- a/src/common/py-serde/src/lib.rs +++ b/src/common/py-serde/src/lib.rs @@ -1,5 +1,6 @@ -#[cfg(feature = "python")] mod python; #[cfg(feature = "python")] pub use crate::{python::deserialize_py_object, python::serialize_py_object}; + +pub use bincode; diff --git a/src/common/py-serde/src/python.rs b/src/common/py-serde/src/python.rs index 8e12abfa00..aa505f4187 100644 --- a/src/common/py-serde/src/python.rs +++ b/src/common/py-serde/src/python.rs @@ -1,6 +1,11 @@ -use pyo3::{PyObject, Python, ToPyObject}; +#[cfg(feature = "python")] +pub use pyo3::PyObject; +#[cfg(feature = "python")] +use pyo3::{Python, ToPyObject}; + use serde::{de::Error as DeError, de::Visitor, ser::Error as SerError, Deserializer, Serializer}; use std::fmt; +#[cfg(feature = "python")] pub fn serialize_py_object(obj: &PyObject, s: S) -> Result where @@ -15,8 +20,10 @@ where })?; s.serialize_bytes(bytes.as_slice()) } +#[cfg(feature = "python")] struct PyObjectVisitor; +#[cfg(feature = "python")] impl<'de> Visitor<'de> for PyObjectVisitor { type Value = PyObject; @@ -57,3 +64,33 @@ where { d.deserialize_bytes(PyObjectVisitor) } + +#[macro_export] +macro_rules! impl_bincode_py_state_serialization { + ($ty:ty) => { + #[cfg(feature = "python")] + #[pymethods] + impl $ty { + pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, PyObject)> { + use pyo3::types::PyBytes; + use pyo3::PyTypeInfo; + use pyo3::ToPyObject; + Ok(( + Self::type_object(py) + .getattr("_from_serialized")? + .to_object(py), + (PyBytes::new(py, &$crate::bincode::serialize(&self).unwrap()).to_object(py),) + .to_object(py), + )) + } + + #[staticmethod] + pub fn _from_serialized(py: Python, serialized: PyObject) -> PyResult { + use pyo3::types::PyBytes; + serialized + .extract::<&PyBytes>(py) + .map(|s| $crate::bincode::deserialize(s.as_bytes()).unwrap()) + } + } + }; +} diff --git a/src/common/resource-request/Cargo.toml b/src/common/resource-request/Cargo.toml index 12244b9283..a2db514585 100644 --- a/src/common/resource-request/Cargo.toml +++ b/src/common/resource-request/Cargo.toml @@ -1,10 +1,11 @@ [dependencies] common-hashable-float-wrapper = {path = "../hashable-float-wrapper"} +common-py-serde = {path = "../py-serde"} pyo3 = {workspace = true, optional = true} serde = {workspace = true} [features] -python = ["dep:pyo3"] +python = ["dep:pyo3", "common-py-serde/python"] [package] edition = {workspace = true} diff --git a/src/common/resource-request/src/lib.rs b/src/common/resource-request/src/lib.rs index 10d8801a78..31e2f66a65 100644 --- a/src/common/resource-request/src/lib.rs +++ b/src/common/resource-request/src/lib.rs @@ -1,6 +1,8 @@ use common_hashable_float_wrapper::FloatWrapper; +use common_py_serde::impl_bincode_py_state_serialization; #[cfg(feature = "python")] -use pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyModule, PyResult, Python}; +use pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyModule, PyObject, PyResult, Python}; + use std::hash::{Hash, Hasher}; use std::ops::Add; @@ -219,6 +221,7 @@ impl ResourceRequest { Ok(format!("{:?}", self)) } } +impl_bincode_py_state_serialization!(ResourceRequest); #[cfg(feature = "python")] pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { diff --git a/src/daft-core/Cargo.toml b/src/daft-core/Cargo.toml index 80bcc303b2..e0e8fe6ec2 100644 --- a/src/daft-core/Cargo.toml +++ b/src/daft-core/Cargo.toml @@ -24,6 +24,7 @@ common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"} +common-py-serde = {path = "../common/py-serde", default-features = false} daft-minhash = {path = "../daft-minhash", default-features = false} daft-sketch = {path = "../daft-sketch", default-features = false} fastrand = "2.1.0" @@ -62,7 +63,7 @@ features = ["xxh3", "const_xxh3"] version = "0.8.5" [features] -python = ["dep:pyo3", "dep:numpy", "common-error/python"] +python = ["dep:pyo3", "dep:numpy", "common-error/python", "common-py-serde/python"] [package] edition = {workspace = true} diff --git a/src/daft-core/src/count_mode.rs b/src/daft-core/src/count_mode.rs index dbfcd30ea8..43e5fbc6e9 100644 --- a/src/daft-core/src/count_mode.rs +++ b/src/daft-core/src/count_mode.rs @@ -1,11 +1,10 @@ +use common_py_serde::impl_bincode_py_state_serialization; #[cfg(feature = "python")] -use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes, PyTypeInfo}; +use pyo3::{exceptions::PyValueError, prelude::*}; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter, Result}; use std::str::FromStr; -use crate::impl_bincode_py_state_serialization; - use common_error::{DaftError, DaftResult}; /// Supported count modes for Daft's count aggregation. diff --git a/src/daft-core/src/join.rs b/src/daft-core/src/join.rs index 018fba15b9..33600cd32d 100644 --- a/src/daft-core/src/join.rs +++ b/src/daft-core/src/join.rs @@ -3,13 +3,10 @@ use std::{ str::FromStr, }; -use crate::impl_bincode_py_state_serialization; use common_error::{DaftError, DaftResult}; +use common_py_serde::impl_bincode_py_state_serialization; #[cfg(feature = "python")] -use pyo3::{ - exceptions::PyValueError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, - Python, ToPyObject, -}; +use pyo3::{exceptions::PyValueError, pyclass, pymethods, PyObject, PyResult, Python}; use serde::{Deserialize, Serialize}; @@ -41,7 +38,6 @@ impl JoinType { Ok(self.to_string()) } } - impl_bincode_py_state_serialization!(JoinType); impl JoinType { @@ -106,7 +102,6 @@ impl JoinStrategy { Ok(self.to_string()) } } - impl_bincode_py_state_serialization!(JoinStrategy); impl JoinStrategy { diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 398a4fdf9e..082605b739 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -1,13 +1,14 @@ use crate::{ datatypes::{DataType, Field, ImageMode, TimeUnit}, - ffi, impl_bincode_py_state_serialization, + ffi, }; + +use common_py_serde::impl_bincode_py_state_serialization; use pyo3::{ class::basic::CompareOp, exceptions::PyValueError, prelude::*, - types::{PyBytes, PyDict, PyString}, - PyTypeInfo, + types::{PyDict, PyString}, }; use serde::{Deserialize, Serialize}; diff --git a/src/daft-core/src/python/field.rs b/src/daft-core/src/python/field.rs index 6529edd863..e6e54491aa 100644 --- a/src/daft-core/src/python/field.rs +++ b/src/daft-core/src/python/field.rs @@ -1,9 +1,9 @@ -use pyo3::{prelude::*, types::PyBytes, PyTypeInfo}; +use pyo3::prelude::*; use serde::{Deserialize, Serialize}; use super::datatype::PyDataType; use crate::datatypes; -use crate::impl_bincode_py_state_serialization; +use common_py_serde::impl_bincode_py_state_serialization; #[pyclass(module = "daft.daft")] #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs index 5d6cce9ca4..33fce18df0 100644 --- a/src/daft-core/src/python/schema.rs +++ b/src/daft-core/src/python/schema.rs @@ -1,8 +1,6 @@ use std::sync::Arc; use pyo3::prelude::*; -use pyo3::types::PyBytes; -use pyo3::PyTypeInfo; use serde::{Deserialize, Serialize}; @@ -10,8 +8,8 @@ use super::datatype::PyDataType; use super::field::PyField; use crate::datatypes; use crate::ffi::field_to_py; -use crate::impl_bincode_py_state_serialization; use crate::schema; +use common_py_serde::impl_bincode_py_state_serialization; #[pyclass(module = "daft.daft")] #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/src/daft-core/src/utils/mod.rs b/src/daft-core/src/utils/mod.rs index 239f30af79..ba9164f8ad 100644 --- a/src/daft-core/src/utils/mod.rs +++ b/src/daft-core/src/utils/mod.rs @@ -3,8 +3,6 @@ pub mod display_table; pub mod dyn_compare; pub mod supertype; -pub use bincode; - #[macro_export] macro_rules! impl_binary_trait_by_reference { ($ty:ty, $trait:ident, $fname:ident) => { @@ -16,32 +14,3 @@ macro_rules! impl_binary_trait_by_reference { } }; } - -#[macro_export] -macro_rules! impl_bincode_py_state_serialization { - ($ty:ty) => { - #[cfg(feature = "python")] - #[pymethods] - impl $ty { - pub fn __reduce__(&self, py: Python) -> PyResult<(PyObject, PyObject)> { - Ok(( - Self::type_object(py) - .getattr("_from_serialized")? - .to_object(py), - ( - PyBytes::new(py, &$crate::utils::bincode::serialize(&self).unwrap()) - .to_object(py), - ) - .to_object(py), - )) - } - - #[staticmethod] - pub fn _from_serialized(py: Python, serialized: PyObject) -> PyResult { - serialized - .extract::<&PyBytes>(py) - .map(|s| $crate::utils::bincode::deserialize(s.as_bytes()).unwrap()) - } - } - }; -} diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index 9e897af517..4e58223843 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -4,6 +4,7 @@ async-compat = {workspace = true} async-compression = {workspace = true} async-stream = {workspace = true} common-error = {path = "../common/error", default-features = false} +common-py-serde = {path = "../common/py-serde", default-features = false} csv-async = "1.3.0" daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} @@ -24,7 +25,7 @@ url = {workspace = true} rstest = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-dsl/python"] +python = ["dep:pyo3", "common-error/python", "common-py-serde/python", "daft-core/python", "daft-io/python", "daft-table/python", "daft-dsl/python"] [package] edition = {workspace = true} diff --git a/src/daft-csv/src/options.rs b/src/daft-csv/src/options.rs index 15349f36f3..17106035f9 100644 --- a/src/daft-csv/src/options.rs +++ b/src/daft-csv/src/options.rs @@ -1,14 +1,12 @@ -use daft_core::{impl_bincode_py_state_serialization, schema::SchemaRef}; +use common_py_serde::impl_bincode_py_state_serialization; +use daft_core::schema::SchemaRef; use daft_dsl::ExprRef; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use { daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, - pyo3::{ - pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, - Python, ToPyObject, - }, + pyo3::{pyclass, pyclass::CompareOp, pymethods, PyObject, PyResult, Python}, }; /// Options for converting CSV data to Daft data. @@ -148,7 +146,6 @@ impl CsvConvertOptions { Ok(format!("{:?}", self)) } } - impl_bincode_py_state_serialization!(CsvConvertOptions); /// Options for parsing CSV files. @@ -374,5 +371,4 @@ impl CsvReadOptions { Ok(format!("{:?}", self)) } } - impl_bincode_py_state_serialization!(CsvReadOptions); diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 4217ed68c0..0e1eae035c 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -4,6 +4,7 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use common_error::DaftError; +use common_py_serde::impl_bincode_py_state_serialization; use common_resource_request::ResourceRequest; use daft_core::array::ops::Utf8NormalizeOptions; use daft_core::python::datatype::PyTimeUnit; @@ -14,7 +15,6 @@ use crate::{functions, Expr, ExprRef, LiteralValue}; use daft_core::{ count_mode::CountMode, datatypes::{ImageFormat, ImageMode}, - impl_bincode_py_state_serialization, python::{datatype::PyDataType, field::PyField, schema::PySchema}, }; @@ -23,7 +23,6 @@ use pyo3::{ prelude::*, pyclass::CompareOp, types::{PyBool, PyBytes, PyFloat, PyInt, PyString}, - PyTypeInfo, }; #[pyfunction] diff --git a/src/daft-json/Cargo.toml b/src/daft-json/Cargo.toml index 49d0690a92..f28c6bbead 100644 --- a/src/daft-json/Cargo.toml +++ b/src/daft-json/Cargo.toml @@ -2,6 +2,7 @@ arrow2 = {workspace = true, features = ["io_json"]} chrono = {workspace = true} common-error = {path = "../common/error", default-features = false} +common-py-serde = {path = "../common/py-serde", default-features = false} daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-decoding = {path = "../daft-decoding"} @@ -30,6 +31,7 @@ rstest = {workspace = true} python = [ "dep:pyo3", "common-error/python", + "common-py-serde/python", "daft-core/python", "daft-io/python", "daft-table/python", diff --git a/src/daft-json/src/options.rs b/src/daft-json/src/options.rs index 625fd837e3..feeae7559a 100644 --- a/src/daft-json/src/options.rs +++ b/src/daft-json/src/options.rs @@ -1,14 +1,12 @@ -use daft_core::{impl_bincode_py_state_serialization, schema::SchemaRef}; +use common_py_serde::impl_bincode_py_state_serialization; +use daft_core::schema::SchemaRef; use daft_dsl::ExprRef; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use { daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, - pyo3::{ - pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, - Python, ToPyObject, - }, + pyo3::{pyclass, pyclass::CompareOp, pymethods, PyObject, PyResult, Python}, }; /// Options for converting JSON data to Daft data. @@ -117,7 +115,6 @@ impl JsonConvertOptions { Ok(format!("{:?}", self)) } } - impl_bincode_py_state_serialization!(JsonConvertOptions); /// Options for parsing JSON files. @@ -160,7 +157,6 @@ impl JsonParseOptions { Ok(format!("{:?}", self)) } } - impl_bincode_py_state_serialization!(JsonParseOptions); /// Options for reading JSON files. @@ -227,5 +223,4 @@ impl JsonReadOptions { Ok(format!("{:?}", self)) } } - impl_bincode_py_state_serialization!(JsonReadOptions); diff --git a/src/daft-plan/src/source_info/file_info.rs b/src/daft-plan/src/source_info/file_info.rs index aff39189aa..36c84dc6c8 100644 --- a/src/daft-plan/src/source_info/file_info.rs +++ b/src/daft-plan/src/source_info/file_info.rs @@ -1,15 +1,13 @@ use arrow2::array::Array; use common_error::DaftResult; -use daft_core::{impl_bincode_py_state_serialization, Series}; +use common_py_serde::impl_bincode_py_state_serialization; +use daft_core::Series; use daft_table::Table; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use { daft_table::python::PyTable, - pyo3::{ - exceptions::PyKeyError, pyclass, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo, - Python, ToPyObject, - }, + pyo3::{exceptions::PyKeyError, pyclass, pymethods, PyObject, PyResult, Python}, }; /// Metadata for a single file. diff --git a/src/daft-scan/src/file_format.rs b/src/daft-scan/src/file_format.rs index ef0954facd..90ccb708b6 100644 --- a/src/daft-scan/src/file_format.rs +++ b/src/daft-scan/src/file_format.rs @@ -1,19 +1,16 @@ use common_error::{DaftError, DaftResult}; -use daft_core::{ - datatypes::{Field, TimeUnit}, - impl_bincode_py_state_serialization, -}; +use daft_core::datatypes::{Field, TimeUnit}; use serde::{Deserialize, Serialize}; use std::hash::Hash; use std::{collections::BTreeMap, str::FromStr, sync::Arc}; + +use common_py_serde::impl_bincode_py_state_serialization; + #[cfg(feature = "python")] use { common_py_serde::{deserialize_py_object, serialize_py_object}, daft_core::python::{datatype::PyTimeUnit, field::PyField}, - pyo3::{ - pyclass, pyclass::CompareOp, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, - PyTypeInfo, Python, ToPyObject, - }, + pyo3::{pyclass, pyclass::CompareOp, pymethods, IntoPy, PyObject, PyResult, Python}, }; /// Format of a file, e.g. Parquet, CSV, JSON. diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 101d77c59d..f2db1eae05 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -939,7 +939,6 @@ mod test { fn make_scan_task(num_sources: usize) -> ScanTask { let sources = (0..num_sources) - .into_iter() .map(|i| DataSource::File { path: format!("test{}", i), chunk_spec: None, diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index 9cb834078c..5652ba33be 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -40,22 +40,20 @@ impl PartialEq for PythonTablesFactoryArgs { pub mod pylib { use common_error::DaftResult; + use common_py_serde::impl_bincode_py_state_serialization; use daft_core::python::field::PyField; use daft_core::schema::SchemaRef; use daft_dsl::python::PyExpr; - use daft_core::impl_bincode_py_state_serialization; use daft_stats::PartitionSpec; use daft_stats::TableMetadata; use daft_stats::TableStatistics; use daft_table::python::PyTable; use daft_table::Table; use pyo3::prelude::*; - use pyo3::types::PyBytes; + use pyo3::types::PyIterator; use pyo3::types::PyList; - use pyo3::PyTypeInfo; - use std::sync::Arc; use daft_core::python::schema::PySchema; diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs index 640861481e..7f1ffc6661 100644 --- a/src/daft-scan/src/storage_config.rs +++ b/src/daft-scan/src/storage_config.rs @@ -2,17 +2,14 @@ use std::sync::Arc; use common_error::DaftResult; use common_io_config::IOConfig; -use daft_core::impl_bincode_py_state_serialization; +use common_py_serde::impl_bincode_py_state_serialization; use daft_io::{get_io_client, get_runtime, IOClient}; use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use { common_io_config::python, - pyo3::{ - pyclass, pymethods, types::PyBytes, IntoPy, PyObject, PyResult, PyTypeInfo, Python, - ToPyObject, - }, + pyo3::{pyclass, pymethods, IntoPy, PyObject, PyResult, Python}, std::hash::{Hash, Hasher}, }; diff --git a/src/daft-scheduler/Cargo.toml b/src/daft-scheduler/Cargo.toml index e1b238fc7d..8ba0fafe9f 100644 --- a/src/daft-scheduler/Cargo.toml +++ b/src/daft-scheduler/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-py-serde = {path = "../common/py-serde", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-plan = {path = "../daft-plan", default-features = false} @@ -20,6 +21,7 @@ python = [ "common-error/python", "common-io-config/python", "common-daft-config/python", + "common-py-serde/python", "daft-core/python", "daft-plan/python", "daft-dsl/python" diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 6f85d8612a..4aaa842e6e 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -1,5 +1,6 @@ use common_display::mermaid::MermaidDisplayOptions; use common_error::DaftResult; +use common_py_serde::impl_bincode_py_state_serialization; use daft_plan::{logical_to_physical, PhysicalPlan, PhysicalPlanRef, QueryStageOutput}; use serde::{Deserialize, Serialize}; @@ -14,14 +15,10 @@ use { daft_dsl::Expr, daft_plan::{OutputFileInfo, PyLogicalPlanBuilder}, daft_scan::{file_format::FileFormat, python::pylib::PyScanTask}, - pyo3::{ - pyclass, pymethods, types::PyBytes, PyObject, PyRef, PyRefMut, PyResult, PyTypeInfo, - Python, ToPyObject, - }, + pyo3::{pyclass, pymethods, PyObject, PyRef, PyRefMut, PyResult, Python}, std::collections::HashMap, }; -use daft_core::impl_bincode_py_state_serialization; use daft_dsl::ExprRef; use daft_plan::InMemoryInfo; use std::sync::Arc; diff --git a/src/parquet2/src/lib.rs b/src/parquet2/src/lib.rs index c8da98621d..db04478f99 100644 --- a/src/parquet2/src/lib.rs +++ b/src/parquet2/src/lib.rs @@ -30,14 +30,3 @@ const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; /// The number of bytes read at the end of the parquet file on first read const DEFAULT_FOOTER_READ_SIZE: u64 = 64 * 1024; - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - pub fn get_path() -> PathBuf { - let dir = env!("CARGO_MANIFEST_DIR"); - - PathBuf::from(dir).join("testing/parquet-testing/data") - } -} diff --git a/tests/test_resource_requests.py b/tests/test_resource_requests.py index 13ccf0cb74..9255b9163b 100644 --- a/tests/test_resource_requests.py +++ b/tests/test_resource_requests.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import os import pytest @@ -52,6 +53,21 @@ def test_partial_resource_request_overrides(): assert new_udf.resource_request.memory_bytes == 100 +def test_resource_request_pickle_roundtrip(): + new_udf = my_udf.override_options(num_cpus=1.0) + assert new_udf.resource_request.num_cpus == 1.0 + assert new_udf.resource_request.num_gpus is None + assert new_udf.resource_request.memory_bytes is None + + assert new_udf == copy.deepcopy(new_udf) + + new_udf = new_udf.override_options(num_gpus=8.0) + assert new_udf.resource_request.num_cpus == 1.0 + assert new_udf.resource_request.num_gpus == 8.0 + assert new_udf.resource_request.memory_bytes is None + assert new_udf == copy.deepcopy(new_udf) + + ### # Assert PyRunner behavior for GPU requests: # Fail if requesting more GPUs than is available, but otherwise we do not modify anything