diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index e42c130779..7710df76f8 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2790,11 +2790,9 @@ def _dataframe_to_data_files(
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
-                    for batches in bin_pack_arrow_table(df, target_file_size)
-                ]
+            tasks=(
+                WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=task_schema)
+                for batches in bin_pack_arrow_table(df, target_file_size)
             ),
         )
     else:
@@ -2802,18 +2800,16 @@ def _dataframe_to_data_files(
         yield from write_file(
             io=io,
             table_metadata=table_metadata,
-            tasks=iter(
-                [
-                    WriteTask(
-                        write_uuid=write_uuid,
-                        task_id=next(counter),
-                        record_batches=batches,
-                        partition_key=partition.partition_key,
-                        schema=task_schema,
-                    )
-                    for partition in partitions
-                    for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
-                ]
+            tasks=(
+                WriteTask(
+                    write_uuid=write_uuid,
+                    task_id=next(counter),
+                    record_batches=batches,
+                    partition_key=partition.partition_key,
+                    schema=task_schema,
+                )
+                for partition in partitions
+                for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
             ),
         )
 
@@ -2824,7 +2820,7 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[_TablePartition]:
     """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.
 
     Example:
@@ -2852,8 +2848,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
     unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
 
-    table_partitions = []
-    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
     for unique_partition in unique_partition_fields.to_pylist():
         partition_key = PartitionKey(
             field_values=[
@@ -2880,12 +2874,11 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
 
         # The combine_chunks seems to be counter-intuitive to do, but it actually returns
         # fresh buffers that don't interfere with each other when it is written out to file
-        table_partitions.append(
-            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        yield _TablePartition(
+            partition_key=partition_key,
+            arrow_table_partition=filtered_table.combine_chunks(),
         )
 
-    return table_partitions
-
 
 def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Array:
     """Get a field from an Arrow table, supporting both literal field names and nested field paths.
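The shape of the change above, reduced to a standalone sketch: `_determine_partitions` now yields each partition as it is built, and `_dataframe_to_data_files` feeds `write_file` with a generator expression instead of a fully materialized list, so write tasks can be consumed as they are produced. The names below (`Partition`, `determine_partitions_lazy`, `write_partitions`) are illustrative stand-ins for this pattern, not pyiceberg API.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator, List


@dataclass
class Partition:
    key: int
    rows: List[int]


def determine_partitions_eager(rows: List[int]) -> List[Partition]:
    # Old shape: build the full list of partitions up front.
    partitions = []
    for key in sorted(set(rows)):
        partitions.append(Partition(key=key, rows=[r for r in rows if r == key]))
    return partitions


def determine_partitions_lazy(rows: List[int]) -> Iterable[Partition]:
    # New shape: yield each partition as it is built, so the caller can start
    # writing before every partition has been materialized.
    for key in sorted(set(rows)):
        yield Partition(key=key, rows=[r for r in rows if r == key])


def write_partitions(tasks: Iterator[Partition]) -> int:
    # The writer needs only a single pass over the tasks.
    return sum(len(p.rows) for p in tasks)


if __name__ == "__main__":
    data = [1, 1, 2, 3, 3, 3]
    # A generator expression feeds the writer without an intermediate list,
    # mirroring the tasks=(...) change in _dataframe_to_data_files.
    assert write_partitions(p for p in determine_partitions_lazy(data)) == len(data)
    assert [p.key for p in determine_partitions_eager(data)] == [1, 2, 3]
```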
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index a19ddd607d..45b9d9c901 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -2479,7 +2479,7 @@ def test_partition_for_demo() -> None:
         PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
         PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
     )
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
     assert {table_partition.partition_key.partition for table_partition in result} == {
         Record(2, 2020),
         Record(100, 2021),
@@ -2518,7 +2518,7 @@ def test_partition_for_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
 
     partition_values = {p.partition_key.partition[0] for p in partitions}
     assert partition_values == {486729, 486730}
@@ -2550,7 +2550,7 @@ def test_partition_for_deep_nested_field() -> None:
     ]
 
     arrow_table = pa.Table.from_pylist(test_data, schema=schema.as_arrow())
-    partitions = _determine_partitions(spec, schema, arrow_table)
+    partitions = list(_determine_partitions(spec, schema, arrow_table))
 
     assert len(partitions) == 2  # 2 unique partitions
     partition_values = {p.partition_key.partition[0] for p in partitions}
@@ -2621,7 +2621,7 @@ def test_identity_partition_on_multi_columns() -> None:
     }
    arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)
 
-    result = _determine_partitions(partition_spec, test_schema, arrow_table)
+    result = list(_determine_partitions(partition_spec, test_schema, arrow_table))
 
     assert {table_partition.partition_key.partition for table_partition in result} == expected
     concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
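The tests now wrap the call in `list()` because `_determine_partitions` returns a generator: a generator is single-pass and has no `len()`, so assertions that count the result or iterate it more than once must materialize it first. A minimal illustration of that behaviour, using a hypothetical `numbers` generator rather than the real function:

```python
def numbers():
    # Stand-in for a generator-returning function such as _determine_partitions.
    yield 1
    yield 2


gen = numbers()
assert sum(gen) == 3
assert sum(gen) == 0  # already exhausted: a generator supports only one pass
# len(gen) would raise TypeError: object of type 'generator' has no len()

materialized = list(numbers())
assert len(materialized) == 2  # len() works on the materialized list
assert sum(materialized) == 3
assert sum(materialized) == 3  # and the list can be iterated repeatedly
```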