41 changes: 17 additions & 24 deletions pyiceberg/io/pyarrow.py
@@ -2790,30 +2790,26 @@ def _dataframe_to_data_files(
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
-                    for batches in bin_pack_arrow_table(df, target_file_size)
-                ]
+            tasks=(
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
+                for batches in bin_pack_arrow_table(df, target_file_size)
             ),
         )
     else:
         partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(
-                        write_uuid=write_uuid,
-                        task_id=next(counter),
-                        record_batches=batches,
-                        partition_key=partition.partition_key,
-                        schema=task_schema,
-                    )
-                    for partition in partitions
-                    for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
-                ]
+            tasks=(
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=task_schema,
+                )
+                for partition in partitions
+                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ),
         )

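Not part of the diff, but worth spelling out why it matters: iter([...]) builds the full list of WriteTask objects before write_file sees the first one, while a generator expression hands them over one at a time. A minimal standalone sketch (make_item is a hypothetical stand-in, not a pyiceberg function):

def make_item(i: int) -> int:
    print(f"building {i}")
    return i * i

eager = iter([make_item(i) for i in range(3)])  # list comprehension runs immediately: prints three lines
lazy = (make_item(i) for i in range(3))         # nothing runs yet
next(lazy)                                      # prints "building 0" only now, on demand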
@@ -2824,7 +2820,7 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[_TablePartition]:
     """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.
 
     Example:
@@ -2852,8 +2848,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
     unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
 
-    table_partitions = []
-    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
     for unique_partition in unique_partition_fields.to_pylist():
         partition_key = PartitionKey(
             field_values=[
@@ -2880,12 +2874,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
         # The combine_chunks seems to be counter-intuitive to do, but it actually returns
         # fresh buffers that don't interfere with each other when it is written out to file
-        table_partitions.append(
-            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        yield _TablePartition(
+            partition_key=partition_key,
+            arrow_table_partition=filtered_table.combine_chunks(),
         )
 
-    return table_partitions
-
 
 def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
     """Get a field from an Arrow table, supporting both literal field names and nested field paths.
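A hedged sketch of what the yield buys callers (handle, spec, schema, and df are placeholders, not pyiceberg APIs): each partition's combined chunks can be garbage-collected before the next one is built, and anyone needing a length or a second pass materializes explicitly, as the updated tests below do.

for partition in _determine_partitions(spec=spec, schema=schema, arrow_table=df):
    handle(partition.arrow_table_partition)  # e.g. write it out, then let it be collected

partitions = list(_determine_partitions(spec=spec, schema=schema, arrow_table=df))  # eager, when needed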
8 changes: 4 additions & 4 deletions tests/io/test_pyarrow.py
@@ -2479,7 +2479,7 @@ def test_partition_for_demo() -> None:
         PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
         PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
     )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
    assert {table_partition.partition_key.partition for table_partition in result} == {
         Record(2, 2020),
         Record(100, 2021),
@@ -2518,7 +2518,7 @@ def test_partition_for_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
     partition_values = {p.partition_key.partition[0] for p in partitions}
 
     assert partition_values == {486729, 486730}
@@ -2550,7 +2550,7 @@ def test_partition_for_deep_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
 
     assert len(partitions) == 2  # 2 unique partitions
     partition_values = {p.partition_key.partition[0] for p in partitions}
@@ -2621,7 +2621,7 @@ def test_identity_partition_on_multi_columns() -> None:
     }
     arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
 
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
 
     assert {table_partition.partition_key.partition for table_partition in result} == expected
     concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
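The list() wrapping in these tests is not cosmetic: a generator is exhausted after one pass, so asserting against the raw generator twice would silently compare against nothing. A minimal sketch of the pitfall:

gen = (x for x in range(3))
assert list(gen) == [0, 1, 2]
assert list(gen) == []  # exhausted: a second pass yields nothing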