Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG]: use recordbatch instead of table for df.to_arrow_iter #2724

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,16 @@ def iter_rows(self, results_buffer_size: Optional[int] = NUM_CPUS) -> Iterator[D
yield row

@DataframePublicAPI
- def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.Table"]:
+ def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pyarrow.RecordBatch"]:
"""
- Return an iterator of pyarrow tables for this dataframe.
+ Return an iterator of pyarrow recordbatches for this dataframe.
"""
if results_buffer_size is not None and not results_buffer_size > 0:
raise ValueError(f"Provided `results_buffer_size` value must be > 0, received: {results_buffer_size}")
if self._result is not None:
# If the dataframe has already finished executing,
# use the precomputed results.
- yield self.to_arrow()
+ yield from self.to_arrow().to_batches()

else:
# Execute the dataframe in a streaming fashion.
Expand All @@ -304,7 +304,7 @@ def to_arrow_iter(self, results_buffer_size: Optional[int] = 1) -> Iterator["pya

# Iterate through partitions.
for partition in partitions_iter:
- yield partition.to_arrow()
+ yield from partition.to_arrow().to_batches()

@DataframePublicAPI
def iter_partitions(
Expand Down
2 changes: 1 addition & 1 deletion tests/table/test_from_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,4 +669,4 @@ def __iter__(self):
def test_to_arrow_iterator() -> None:
df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
it = df.to_arrow_iter()
- assert isinstance(next(it), pa.Table)
+ assert isinstance(next(it), pa.RecordBatch)
Loading