Partitioned Append on Identity Transform #555

Merged: 20 commits, Apr 5, 2024
Changes from 11 commits
4 changes: 2 additions & 2 deletions pyiceberg/io/pyarrow.py
@@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
@@ -1787,7 +1787,7 @@ def write_parquet(task: WriteTask) -> DataFile:
content=DataFileContent.DATA,
file_path=file_path,
file_format=FileFormat.PARQUET,
partition=Record(),
partition=task.partition_key.partition if task.partition_key else Record(),
file_size_in_bytes=len(fo),
# After this has been fixed:
# https://github.com/apache/iceberg-python/issues/271
6 changes: 0 additions & 6 deletions pyiceberg/manifest.py
@@ -41,7 +41,6 @@
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
IcebergType,
IntegerType,
ListType,
@@ -51,8 +50,6 @@
PrimitiveType,
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
)

@@ -289,10 +286,7 @@ def partition_field_to_data_file_partition_field(partition_field_type: IcebergTy


@partition_field_to_data_file_partition_field.register(LongType)
@partition_field_to_data_file_partition_field.register(DateType)
Contributor: This single-dispatch seems to be there only for the TimeType. We should probably also convert those into a native type.

Contributor (author): Fixed in commit 82dd3ad.

Contributor: Beautiful, thanks 👍

@partition_field_to_data_file_partition_field.register(TimeType)
@partition_field_to_data_file_partition_field.register(TimestampType)
@partition_field_to_data_file_partition_field.register(TimestamptzType)
def _(partition_field_type: PrimitiveType) -> IntegerType:
return IntegerType()
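For context on the "native type" conversion discussed above, Iceberg stores date and timestamp partition values as integers (days or microseconds since the epoch). A minimal sketch, assuming the conversion helpers in pyiceberg.utils.datetime:

from datetime import date, datetime, timezone

from pyiceberg.utils.datetime import date_to_days, datetime_to_micros

# Dates become days since 1970-01-01, timestamps become microseconds since the epoch.
print(date_to_days(date(2023, 3, 1)))                                 # 19417
print(datetime_to_micros(datetime(2023, 3, 1, tzinfo=timezone.utc)))  # 1677628800000000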

179 changes: 162 additions & 17 deletions pyiceberg/table/__init__.py
@@ -16,13 +16,13 @@
# under the License.
from __future__ import annotations

import datetime
import itertools
import uuid
import warnings
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import cached_property, singledispatch
from itertools import chain
@@ -77,6 +77,8 @@
INITIAL_PARTITION_SPEC_ID,
PARTITION_FIELD_ID_START,
PartitionField,
PartitionFieldValue,
PartitionKey,
PartitionSpec,
_PartitionNameGenerator,
_visit_partition_field,
@@ -716,7 +718,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
if update.ref_name == MAIN_BRANCH:
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
if "last_updated_ms" not in metadata_updates:
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone())

metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
SnapshotLogEntry(
@@ -1131,8 +1133,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
supported = {IdentityTransform}
Contributor: Nit, suggested change:
- supported = {IdentityTransform}
+ supported_transforms = {IdentityTransform}
if not all(type(field.transform) in supported for field in self.metadata.spec().fields):
raise ValueError(
f"All transforms are not supported, expected: {supported}, but get: {[str(field) for field in self.metadata.spec().fields if field.transform not in supported]}."
)

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
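For illustration, the transform check above boils down to a set-membership test on transform types; a standalone sketch (not part of the diff, the transform values are hypothetical):

from pyiceberg.transforms import BucketTransform, IdentityTransform

supported = {IdentityTransform}
transforms = [IdentityTransform(), BucketTransform(num_buckets=16)]

# Mirrors the append() guard: every partition field must use a supported transform.
print(all(type(t) in supported for t in transforms))  # False, so append() would raise ValueError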
@@ -2492,16 +2497,28 @@ def _add_and_move_fields(
class WriteTask:
write_uuid: uuid.UUID
task_id: int
schema: Schema
record_batches: List[pa.RecordBatch]
sort_order_id: Optional[int] = None
partition_key: Optional[PartitionKey] = None

# Later to be extended with partition information
def generate_data_file_partition_path(self) -> str:
Contributor: Nit: This function looks redundant. The check is being done in generate_data_file_path() as well. I would merge those two.
if self.partition_key is None:
raise ValueError("Cannot generate partition path based on non-partitioned WriteTask")
return self.partition_key.to_path()

def generate_data_file_filename(self, extension: str) -> str:
# Mimics the behavior in the Java API:
# https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
return f"00000-{self.task_id}-{self.write_uuid}.{extension}"

def generate_data_file_path(self, extension: str) -> str:
if self.partition_key:
file_path = f"{self.generate_data_file_partition_path()}/{self.generate_data_file_filename(extension)}"
return file_path
else:
return self.generate_data_file_filename(extension)
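As the reviewer suggests above, the two helpers could be folded into one; a minimal sketch of the merged method (hypothetical, not what this revision does):

def generate_data_file_path(self, extension: str) -> str:
    # Prefix the filename with the Hive-style partition path when a partition key is
    # present, e.g. "n_legs=4/year=2022/00000-0-<uuid>.parquet"; otherwise return the
    # bare filename, as for unpartitioned tables.
    if self.partition_key:
        return f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
    return self.generate_data_file_filename(extension)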


@dataclass(frozen=True)
class AddFileTask:
@@ -2529,25 +2546,44 @@ def _dataframe_to_data_files(
"""
from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file

if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
raise ValueError("Cannot write to partitioned tables")

counter = itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()

target_file_size = PropertyUtil.property_as_int(
properties=table_metadata.properties,
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
if target_file_size is None:
raise ValueError(
"Fail to get neither TableProperties.WRITE_TARGET_FILE_SIZE_BYTES nor WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT for writing target data file."
sungwy (Collaborator), Mar 29, 2024: I have mixed feelings about this exception check, because we set the default value of target_file_size to TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT right on the previous line, so it feels too redundant. I understand why we are doing it, though: PropertyUtil.property_as_int returns Optional[int], and bin packing expects an int, so we need to narrow the type. If we run into more of these type-checking redundancies in the code base, where a property is always expected to have a non-null default, maybe we should refactor PropertyUtil instead. Maybe we can have two methods: property_as_int, which returns Optional[int], and property_as_int_with_default, which returns int.

jqin61 (Contributor, author): property_as_int_with_default sounds better to me, because all the exceptions raised due to a missing default property could be centralized in that function. How do you feel about it?

Contributor: I like that as well; the ValueError is misleading, and it is not directly obvious why we would raise it.

jqin61 (author), Apr 2, 2024: I just found that the default value itself can be None, e.g. PARQUET_COMPRESSION_LEVEL_DEFAULT = None, so this None check is not unnecessary. The original code for this target_file_size check just suppresses it with a type: ignore.

)
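A rough sketch of the property_as_int_with_default helper discussed above (an assumption about a possible refactor, not part of this PR):

from typing import Dict, Optional

def property_as_int_with_default(properties: Dict[str, str], property_name: str, default: int) -> int:
    # Always returns an int: falls back to the caller-supplied default when the property
    # is absent, so call sites no longer need a redundant None check.
    value: Optional[str] = properties.get(property_name)
    if value is None:
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(f"Could not parse table property {property_name} to an integer: {value}") from err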

# This is an iter, so we don't have to materialize everything every time
# This will be more relevant when we start doing partitioned writes
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]), # type: ignore
)
if len(table_metadata.spec().fields) > 0:
partitions = partition(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(
write_uuid=write_uuid,
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
Contributor: This looks very nice!
]),
)
else:
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)


def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
Expand Down Expand Up @@ -3099,7 +3135,7 @@ def snapshots(self) -> "pa.Table":
additional_properties = None

snapshots.append({
'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
'snapshot_id': snapshot.snapshot_id,
'parent_id': snapshot.parent_snapshot_id,
'operation': str(operation),
@@ -3111,3 +3147,112 @@
snapshots,
schema=snapshots_schema,
)


@dataclass(frozen=True)
class TablePartition:
partition_key: PartitionKey
arrow_table_partition: pa.Table


def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
order = 'ascending' if not reverse else 'descending'
null_placement = 'at_start' if reverse else 'at_end'
return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}


def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
"""Given a table, sort it by current partition scheme."""
# only works for identity for now
sort_options = _get_partition_sort_order(partition_columns, reverse=False)
sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
return sorted_arrow_table


def get_partition_columns(
spec: PartitionSpec,
schema: Schema,
) -> list[str]:
partition_cols = []
for partition_field in spec.fields:
column_name = schema.find_column_name(partition_field.source_id)
if not column_name:
raise ValueError(f"{partition_field=} could not be found in {schema}.")
partition_cols.append(column_name)
return partition_cols


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def partition(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[TablePartition]:
Contributor: It would be good to have a somewhat longer, more descriptive name here, and I also think we should hide this from the outside user. Suggested change:
- def partition(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[TablePartition]:
+ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
I think we can also return a list, so folks know that it is already materialized.
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.

Example:
Input:
An arrow table with partition key of ['n_legs', 'year'] and with data of
{'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_end".
This gives the same table as raw input.
Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
and null_placement : "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
"""
import pyarrow as pa

partition_columns = get_partition_columns(spec=spec, schema=schema)
arrow_table = group_by_partition_scheme(arrow_table, partition_columns)

reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()

slice_instructions: list[dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions
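A rough usage sketch of the partition helper above, using the data from its docstring. It assumes the function is importable from pyiceberg.table under the name used in this revision (the review above suggests renaming it to _determine_partitions), and uses LongType fields so the Arrow int64 columns line up with the schema.

import pyarrow as pa

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import partition
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import LongType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="year", field_type=LongType(), required=False),
    NestedField(field_id=2, name="n_legs", field_type=LongType(), required=False),
    NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
)
spec = PartitionSpec(
    PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="n_legs"),
    PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year"),
)
arrow_table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})

# Each TablePartition carries a PartitionKey plus the matching slice of the Arrow table.
for part in partition(spec=spec, schema=schema, arrow_table=arrow_table):
    print(part.partition_key.to_path(), part.arrow_table_partition.num_rows)
# e.g. "n_legs=2/year=2020 1", "n_legs=2/year=2022 2", ..., "n_legs=100/year=2021 1"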
4 changes: 4 additions & 0 deletions pyiceberg/typedef.py
@@ -199,3 +199,7 @@ def __repr__(self) -> str:
def record_fields(self) -> List[str]:
"""Return values of all the fields of the Record class except those specified in skip_fields."""
return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]

def __hash__(self) -> int:
"""Return hash value of the Record class."""
return hash(str(self))
24 changes: 22 additions & 2 deletions tests/conftest.py
@@ -30,7 +30,7 @@
import socket
import string
import uuid
from datetime import date, datetime
from datetime import date, datetime, timezone
from pathlib import Path
from random import choice
from tempfile import TemporaryDirectory
@@ -2000,7 +2000,11 @@ def spark() -> "SparkSession":
'float': [0.0, None, 0.9],
'double': [0.0, None, 0.9],
'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)],
'timestamptz': [
    datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc),
    None,
    datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc),
],

Contributor: Nice one!
'date': [date(2023, 1, 1), None, date(2023, 3, 1)],
# Not supported by Spark
# 'time': [time(1, 22, 0), None, time(19, 25, 0)],
@@ -2045,3 +2049,19 @@ def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table":

"""PyArrow table with all kinds of columns"""
return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)


@pytest.fixture(scope="session")
def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table":
import pyarrow as pa

"""PyArrow table with all kinds of columns."""
Contributor: Suggested change, put the docstring before the import:
- import pyarrow as pa
- """PyArrow table with all kinds of columns."""
+ """PyArrow table with all kinds of columns."""
+ import pyarrow as pa

return pa.Table.from_pylist([], schema=pa_schema)


@pytest.fixture(scope="session")
def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
import pyarrow as pa

"""PyArrow table with all kinds of columns."""
Contributor: Same suggestion here, put the docstring before the import:
- import pyarrow as pa
- """PyArrow table with all kinds of columns."""
+ """PyArrow table with all kinds of columns."""
+ import pyarrow as pa

return pa.Table.from_pylist([{}, {}], schema=pa_schema)
2 changes: 1 addition & 1 deletion tests/integration/test_partitioning_key.py
@@ -749,7 +749,7 @@ def test_partition_key(
# key.to_path() generates the hive partitioning part of the to-write parquet file path
assert key.to_path() == expected_hive_partition_path_slice

# Justify expected values are not made up but conform to spark behaviors
# Justify expected values are not made up but conforming to spark behaviors
if spark_create_table_sql_for_justification is not None and spark_data_insert_sql_for_justification is not None:
try:
spark.sql(f"drop table {identifier}")