diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py
index 874199ed8f..c32e69ebe5 100644
--- a/daft/dataframe/dataframe.py
+++ b/daft/dataframe/dataframe.py
@@ -65,9 +65,6 @@
 
 ManyColumnsInputType = Union[ColumnInputType, Iterable[ColumnInputType]]
 
-NUM_CPUS = multiprocessing.cpu_count()
-
-
 class DataFrame:
     """A Daft DataFrame is a table of data. It has columns, where each column has a type and the same
     number of items (rows) as all other columns.
@@ -226,7 +223,9 @@ def __iter__(self) -> Iterator[Dict[str, Any]]:
         return self.iter_rows(results_buffer_size=None)
 
     @DataframePublicAPI
-    def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[Dict[str, Any]]:
+    def iter_rows(
+        self, results_buffer_size: Union[Optional[int], Literal["num_cpus"]] = "num_cpus"
+    ) -> Iterator[Dict[str, Any]]:
         """Return an iterator of rows for this dataframe.
 
         Each row will be a Python dictionary of the form { "key" : value, ... }. If you are instead looking to iterate over
@@ -263,6 +262,9 @@ def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[D
         .. seealso::
             :meth:`df.iter_partitions() <daft.DataFrame.iter_partitions>`: iterator over entire partitions instead of single rows
         """
+        if results_buffer_size == "num_cpus":
+            results_buffer_size = multiprocessing.cpu_count()
+
         if self._result is not None:
             # If the dataframe has already finished executing,
             # use the precomputed results.
@@ -270,7 +272,6 @@ def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[D
             for i in range(len(self)):
                 row = {key: value[i] for (key, value) in pydict.items()}
                 yield row
-
         else:
             # Execute the dataframe in a streaming fashion.
             context = get_context()
@@ -286,21 +287,32 @@ def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[D
                 yield row
 
     @DataframePublicAPI
-    def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.RecordBatch"]:
+    def to_arrow_iter(
+        self,
+        results_buffer_size: Union[Optional[int], Literal["num_cpus"]] = "num_cpus",
+    ) -> Iterator["pyarrow.RecordBatch"]:
         """
         Return an iterator of pyarrow recordbatches for this dataframe.
         """
+        for name in self.schema().column_names():
+            if self.schema()[name].dtype._is_python_type():
+                raise ValueError(
+                    f"Cannot convert column {name} to Arrow type, found Python type: {self.schema()[name].dtype}"
+                )
+
+        if results_buffer_size == "num_cpus":
+            results_buffer_size = multiprocessing.cpu_count()
         if results_buffer_size is not None and not results_buffer_size > 0:
             raise ValueError(f"Provided `results_buffer_size` value must be > 0, received: {results_buffer_size}")
         if self._result is not None:
             # If the dataframe has already finished executing,
             # use the precomputed results.
-            yield from self.to_arrow().to_batches()
-
+            for _, result in self._result.items():
+                yield from (result.micropartition().to_arrow().to_batches())
         else:
             # Execute the dataframe in a streaming fashion.
             context = get_context()
-            partitions_iter = context.runner().run_iter_tables(self._builder, results_buffer_size)
+            partitions_iter = context.runner().run_iter_tables(self._builder, results_buffer_size=results_buffer_size)
 
             # Iterate through partitions.
             for partition in partitions_iter:
@@ -308,7 +320,7 @@ def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pya
 
     @DataframePublicAPI
     def iter_partitions(
-        self, results_buffer_size: Optional[int] = NUM_CPUS
+        self, results_buffer_size: Union[Optional[int], Literal["num_cpus"]] = "num_cpus"
     ) -> Iterator[Union[MicroPartition, "ray.ObjectRef[MicroPartition]"]]:
         """Begin executing this dataframe and return an iterator over the partitions.
 
@@ -365,7 +377,9 @@ def iter_partitions(
         Statistics: missing
 
         """
-        if results_buffer_size is not None and not results_buffer_size > 0:
+        if results_buffer_size == "num_cpus":
+            results_buffer_size = multiprocessing.cpu_count()
+        elif results_buffer_size is not None and not results_buffer_size > 0:
             raise ValueError(f"Provided `results_buffer_size` value must be > 0, received: {results_buffer_size}")
 
         if self._result is not None:
@@ -2496,17 +2510,10 @@ def to_arrow(self) -> "pyarrow.Table":
         .. NOTE::
             This call is **blocking** and will execute the DataFrame when called
         """
-        for name in self.schema().column_names():
-            if self.schema()[name].dtype._is_python_type():
-                raise ValueError(
-                    f"Cannot convert column {name} to Arrow type, found Python type: {self.schema()[name].dtype}"
-                )
-
-        self.collect()
-        result = self._result
-        assert result is not None
+        import pyarrow as pa
 
-        return result.to_arrow()
+        arrow_rb_iter = self.to_arrow_iter(results_buffer_size=None)
+        return pa.Table.from_batches(arrow_rb_iter, schema=self.schema().to_pyarrow_schema())
 
     @DataframePublicAPI
     def to_pydict(self) -> Dict[str, List[Any]]:
diff --git a/src/daft-core/src/ffi.rs b/src/daft-core/src/ffi.rs
index 67ad5fb234..e7084c9a7a 100644
--- a/src/daft-core/src/ffi.rs
+++ b/src/daft-core/src/ffi.rs
@@ -69,7 +69,7 @@ pub fn field_to_py(
     Ok(field.to_object(py))
 }
 
-pub fn to_py_schema(
+pub fn dtype_to_py(
     dtype: &arrow2::datatypes::DataType,
     py: Python,
     pyarrow: &PyModule,
@@ -81,8 +81,9 @@ pub fn to_py_schema(
         pyo3::intern!(py, "_import_from_c"),
         (schema_ptr as Py_uintptr_t,),
     )?;
+    let dtype = field.getattr(pyo3::intern!(py, "type"))?.to_object(py);
 
-    Ok(field.to_object(py))
+    Ok(dtype.to_object(py))
 }
 
 fn fix_child_array_slice_offsets(array: ArrayRef) -> ArrayRef {
diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs
index e50608b05d..e5615bd072 100644
--- a/src/daft-core/src/python/datatype.rs
+++ b/src/daft-core/src/python/datatype.rs
@@ -331,11 +331,10 @@ impl PyDataType {
                 } else {
                     // Fall back to default Daft super extension representation if installed pyarrow doesn't have the
                     // canonical tensor extension type.
-                    ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
+                    ffi::dtype_to_py(&self.dtype.to_arrow()?, py, pyarrow)?
                 },
             ),
-            _ => ffi::to_py_schema(&self.dtype.to_arrow()?, py, pyarrow)?
-                .getattr(py, pyo3::intern!(py, "type")),
+            _ => ffi::dtype_to_py(&self.dtype.to_arrow()?, py, pyarrow),
         }
     }
 
diff --git a/src/daft-core/src/python/schema.rs b/src/daft-core/src/python/schema.rs
index 33fce18df0..eaa2ff78ac 100644
--- a/src/daft-core/src/python/schema.rs
+++ b/src/daft-core/src/python/schema.rs
@@ -7,7 +7,6 @@ use serde::{Deserialize, Serialize};
 use super::datatype::PyDataType;
 use super::field::PyField;
 use crate::datatypes;
-use crate::ffi::field_to_py;
 use crate::schema;
 use common_py_serde::impl_bincode_py_state_serialization;
 
@@ -29,7 +28,16 @@ impl PySchema {
             .schema
             .fields
             .iter()
-            .map(|(_, f)| field_to_py(&f.to_arrow()?, py, pyarrow))
+            .map(|(_, f)| {
+                // NOTE: Use PyDataType::to_arrow because we need to dip into Python to get
+                // the registered Arrow extension types
+                let py_dtype: PyDataType = f.dtype.clone().into();
+                let py_arrow_dtype = py_dtype.to_arrow(py)?;
+                pyarrow
+                    .getattr(pyo3::intern!(py, "field"))
+                    .unwrap()
+                    .call1((f.name.clone(), py_arrow_dtype))
+            })
             .collect::<PyResult<Vec<_>>>()?;
         pyarrow
             .getattr(pyo3::intern!(py, "schema"))
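Taken together, the Python-side changes replace the module-level NUM_CPUS default with a "num_cpus" sentinel that each streaming entry point resolves to multiprocessing.cpu_count(), and DataFrame.to_arrow() now materializes by draining to_arrow_iter(results_buffer_size=None). A minimal usage sketch of the updated entry points, assuming this diff is applied; the sample data and buffer sizes below are illustrative, not part of the change:

import daft

# Illustrative data; any DataFrame would do.
df = daft.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# New default: buffer up to multiprocessing.cpu_count() intermediate results while streaming rows.
for row in df.iter_rows():
    print(row)

# An explicit integer still bounds the buffer; passing None leaves it unbounded.
for batch in df.to_arrow_iter(results_buffer_size=2):
    print(batch.num_rows)

# to_arrow() now collects the unbounded record-batch stream into a pyarrow.Table.
table = df.to_arrow()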