Partitioned Append on Identity Transform #555

Merged · 20 commits · Apr 5, 2024

Changes from 16 commits
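
As orientation for the diff below: a hedged sketch of the workflow this change enables, appending a PyArrow table to an Iceberg table partitioned with an identity transform. The catalog name and table identifier are placeholders; `tbl.append` comes from this PR, while the `create_table` call uses the standard pyiceberg catalog API and is assumed here rather than shown in the diff.

```python
# Hedged sketch: append to an identity-partitioned table (placeholder catalog/table names).
import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import LongType, NestedField, StringType

catalog = load_catalog("default")  # placeholder catalog configuration

schema = Schema(
    NestedField(field_id=1, name="n_legs", field_type=LongType(), required=False),
    NestedField(field_id=2, name="animal", field_type=StringType(), required=False),
)
# Identity-partition on n_legs (source_id points at the schema field above).
spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="n_legs"))

tbl = catalog.create_table("default.animals", schema=schema, partition_spec=spec)

df = pa.table({"n_legs": [2, 4, 4], "animal": ["Parrot", "Dog", "Horse"]})
tbl.append(df)  # before this PR: ValueError("Cannot write to partitioned tables")
```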
4 changes: 2 additions & 2 deletions pyiceberg/io/pyarrow.py
@@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT
)

def write_parquet(task: WriteTask) -> DataFile:
file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}'
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
@@ -1787,7 +1787,7 @@ def write_parquet(task: WriteTask) -> DataFile:
content=DataFileContent.DATA,
file_path=file_path,
file_format=FileFormat.PARQUET,
partition=Record(),
partition=task.partition_key.partition if task.partition_key else Record(),
file_size_in_bytes=len(fo),
# After this has been fixed:
# https://github.com/apache/iceberg-python/issues/271
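
For context on the two changes above: once a WriteTask carries a partition key, the data file lands under a partition sub-directory and the DataFile records the partition values instead of an empty Record. A minimal standalone sketch of the resulting path layout; the Hive-style "column=value" segment is an assumption about what `PartitionKey.to_path()` yields, not something shown in this hunk.

```python
# Standalone sketch of the path composition in write_parquet / generate_data_file_path.
# "n_legs=4" below is a hypothetical partition path; the real value comes from
# PartitionKey.to_path().
import uuid
from typing import Optional


def sketch_data_file_path(location: str, partition_path: Optional[str], task_id: int, write_uuid: uuid.UUID) -> str:
    filename = f"00000-{task_id}-{write_uuid}.parquet"   # mirrors generate_data_file_filename
    if partition_path:                                    # mirrors generate_data_file_path
        return f"{location}/data/{partition_path}/{filename}"
    return f"{location}/data/{filename}"


print(sketch_data_file_path("s3://bucket/warehouse/db/tbl", "n_legs=4", 0, uuid.uuid4()))
# -> s3://bucket/warehouse/db/tbl/data/n_legs=4/00000-0-<uuid>.parquet
```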
27 changes: 1 addition & 26 deletions pyiceberg/manifest.py
@@ -19,7 +19,6 @@
import math
from abc import ABC, abstractmethod
from enum import Enum
from functools import singledispatch
from types import TracebackType
from typing import (
Any,
@@ -41,8 +40,6 @@
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
IcebergType,
IntegerType,
ListType,
LongType,
@@ -51,9 +48,6 @@
PrimitiveType,
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
)

UNASSIGNED_SEQ = -1
@@ -283,31 +277,12 @@ def __repr__(self) -> str:
}


@singledispatch
def partition_field_to_data_file_partition_field(partition_field_type: IcebergType) -> PrimitiveType:
raise TypeError(f"Unsupported partition field type: {partition_field_type}")


@partition_field_to_data_file_partition_field.register(LongType)
@partition_field_to_data_file_partition_field.register(DateType)
Contributor:

This single-dispatch is there only for the TimeType, it seems. Probably we should also convert those into a native type.

Contributor Author:

Fixed in commit 82dd3ad.

Contributor:

Beautiful, thanks 👍

@partition_field_to_data_file_partition_field.register(TimeType)
@partition_field_to_data_file_partition_field.register(TimestampType)
@partition_field_to_data_file_partition_field.register(TimestamptzType)
def _(partition_field_type: PrimitiveType) -> IntegerType:
return IntegerType()


@partition_field_to_data_file_partition_field.register(PrimitiveType)
def _(partition_field_type: PrimitiveType) -> PrimitiveType:
return partition_field_type


def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
Contributor:

Nit, suggested change:

- def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
+ def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:

data_file_partition_type = StructType(*[
NestedField(
field_id=field.field_id,
name=field.name,
field_type=partition_field_to_data_file_partition_field(field.field_type),
field_type=field.field_type,
required=field.required,
)
for field in partition_type.fields
10 changes: 8 additions & 2 deletions pyiceberg/partitioning.py
@@ -19,7 +19,7 @@
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from datetime import date, datetime, time
from functools import cached_property, singledispatch
from typing import (
Any,
@@ -62,9 +62,10 @@
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros

INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
@@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None


@_to_partition_representation.register(TimeType)
def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
return time_to_micros(value) if value is not None else None


@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None
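
Tying the manifest.py and partitioning.py changes together: the removed singledispatch in manifest.py used to coerce time/timestamp partition field *types* at the schema level, while after this change the coercion happens on the *values*, via `_to_partition_representation`. Below is a standalone sketch of those value conversions, assuming the usual Iceberg encodings (days since epoch for dates, microseconds for times and timestamps); it does not import pyiceberg and only mirrors the helpers used above.

```python
# Sketch of identity-partition value representations (assumed encodings):
#   date      -> days since 1970-01-01
#   time      -> microseconds since midnight
#   timestamp -> microseconds since 1970-01-01T00:00:00
from datetime import date, datetime, time, timedelta

EPOCH_DATE = date(1970, 1, 1)
EPOCH_DATETIME = datetime(1970, 1, 1)


def date_to_days(d: date) -> int:
    return (d - EPOCH_DATE).days


def time_to_micros(t: time) -> int:
    return ((t.hour * 60 + t.minute) * 60 + t.second) * 1_000_000 + t.microsecond


def datetime_to_micros(dt: datetime) -> int:
    return (dt - EPOCH_DATETIME) // timedelta(microseconds=1)


print(date_to_days(date(2024, 4, 5)))                    # 19818
print(time_to_micros(time(12, 30)))                      # 45000000000
print(datetime_to_micros(datetime(2024, 4, 5, 12, 30)))  # 1712320200000000
```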
174 changes: 155 additions & 19 deletions pyiceberg/table/__init__.py
@@ -16,13 +16,13 @@
# under the License.
from __future__ import annotations

import datetime
import itertools
import uuid
import warnings
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import cached_property, singledispatch
from itertools import chain
@@ -77,6 +77,8 @@
INITIAL_PARTITION_SPEC_ID,
PARTITION_FIELD_ID_START,
PartitionField,
PartitionFieldValue,
PartitionKey,
PartitionSpec,
_PartitionNameGenerator,
_visit_partition_field,
@@ -716,7 +718,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl
if update.ref_name == MAIN_BRANCH:
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
if "last_updated_ms" not in metadata_updates:
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone())

metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
SnapshotLogEntry(
@@ -1131,8 +1133,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")
supported_transforms = {IdentityTransform}
if not all(type(field.transform) in supported_transforms for field in self.metadata.spec().fields):
raise ValueError(
f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.metadata.spec().fields if field.transform not in supported_transforms]}."
)

_check_schema_compatible(self.schema(), other_schema=df.schema)
# cast if the two schemas are compatible but not equal
@@ -2492,16 +2497,23 @@ def _add_and_move_fields(
class WriteTask:
write_uuid: uuid.UUID
task_id: int
schema: Schema
record_batches: List[pa.RecordBatch]
sort_order_id: Optional[int] = None

# Later to be extended with partition information
partition_key: Optional[PartitionKey] = None

def generate_data_file_filename(self, extension: str) -> str:
# Mimics the behavior in the Java API:
# https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101
return f"00000-{self.task_id}-{self.write_uuid}.{extension}"

def generate_data_file_path(self, extension: str) -> str:
if self.partition_key:
file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}"
return file_path
else:
return self.generate_data_file_filename(extension)


@dataclass(frozen=True)
class AddFileTask:
@@ -2529,25 +2541,40 @@ def _dataframe_to_data_files(
"""
from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file

if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0:
raise ValueError("Cannot write to partitioned tables")

counter = itertools.count(0)
write_uuid = write_uuid or uuid.uuid4()

target_file_size = PropertyUtil.property_as_int(
target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The property is set with non-None value.
properties=table_metadata.properties,
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)

# This is an iter, so we don't have to materialize everything every time
# This will be more relevant when we start doing partitioned writes
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]), # type: ignore
)
if len(table_metadata.spec().fields) > 0:
partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(
write_uuid=write_uuid,
task_id=next(counter),
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
)
for partition in partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
Contributor:

This looks very nice!

]),
)
else:
yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
for batches in bin_pack_arrow_table(df, target_file_size)
]),
)


def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]:
@@ -3099,7 +3126,7 @@ def snapshots(self) -> "pa.Table":
additional_properties = None

snapshots.append({
'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0),
'snapshot_id': snapshot.snapshot_id,
'parent_id': snapshot.parent_snapshot_id,
'operation': str(operation),
@@ -3111,3 +3138,112 @@ def snapshots(self) -> "pa.Table":
snapshots,
schema=snapshots_schema,
)


@dataclass(frozen=True)
class TablePartition:
partition_key: PartitionKey
arrow_table_partition: pa.Table


def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
order = 'ascending' if not reverse else 'descending'
null_placement = 'at_start' if reverse else 'at_end'
return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}


def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
"""Given a table, sort it by current partition scheme."""
# only works for identity for now
sort_options = _get_partition_sort_order(partition_columns, reverse=False)
sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
return sorted_arrow_table


def get_partition_columns(
spec: PartitionSpec,
schema: Schema,
) -> list[str]:
partition_cols = []
for partition_field in spec.fields:
column_name = schema.find_column_name(partition_field.source_id)
if not column_name:
raise ValueError(f"{partition_field=} could not be found in {schema}.")
partition_cols.append(column_name)
return partition_cols


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset'])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.

Example:
Input:
An arrow table with partition key of ['n_legs', 'year'] and with data of
{'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
First we group the rows into partitions by sorting with sort order [('n_legs', 'ascending'), ('year', 'ascending')]
and null_placement of "at_end".
This gives the same table as the raw input.
Then we compute sort_indices using the reverse order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by these offsets,
and slice the arrow table by the offset and length of each partition.
"""
import pyarrow as pa

partition_columns = get_partition_columns(spec=spec, schema=schema)
arrow_table = group_by_partition_scheme(arrow_table, partition_columns)

reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()

slice_instructions: list[dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions
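
A standalone walk-through of the grouping described in the docstring above (plain Python, no Arrow): after sorting by the partition columns, each contiguous run of equal partition keys becomes one {"offset", "length"} slice. This mirrors what `_determine_partitions` derives via the reversed sort indices; the groups match the docstring example, just listed in ascending rather than reverse order.

```python
# Sketch: derive slice instructions from contiguous runs of equal partition keys.
from itertools import groupby

rows = [
    {"year": 2020, "n_legs": 2, "animal": "Flamingo"},
    {"year": 2022, "n_legs": 2, "animal": "Parrot"},
    {"year": 2022, "n_legs": 2, "animal": "Parrot"},
    {"year": 2021, "n_legs": 4, "animal": "Dog"},
    {"year": 2022, "n_legs": 4, "animal": "Horse"},
    {"year": 2022, "n_legs": 4, "animal": "Horse"},
    {"year": 2022, "n_legs": 4, "animal": "Horse"},
    {"year": 2019, "n_legs": 5, "animal": "Brittle stars"},
    {"year": 2021, "n_legs": 100, "animal": "Centipede"},
]
partition_cols = ["n_legs", "year"]

# Sort by the partition columns (the identity partition scheme).
rows.sort(key=lambda r: tuple(r[c] for c in partition_cols))

# Each contiguous run of equal keys -> one slice instruction (one partition).
slice_instructions = []
offset = 0
for _key, group in groupby(rows, key=lambda r: tuple(r[c] for c in partition_cols)):
    length = len(list(group))
    slice_instructions.append({"offset": offset, "length": length})
    offset += length

print(slice_instructions)
# [{'offset': 0, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 3, 'length': 1},
#  {'offset': 4, 'length': 3}, {'offset': 7, 'length': 1}, {'offset': 8, 'length': 1}]
```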
4 changes: 4 additions & 0 deletions pyiceberg/typedef.py
@@ -199,3 +199,7 @@ def __repr__(self) -> str:
def record_fields(self) -> List[str]:
"""Return values of all the fields of the Record class except those specified in skip_fields."""
return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]

def __hash__(self) -> int:
"""Return hash value of the Record class."""
return hash(str(self))
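
A hedged sketch of why a hash is useful here: with `__hash__` defined, partition Records can serve as dict or set keys, for example when grouping data files by partition. It assumes, without this diff confirming it, that `Record` accepts keyword field values and that equal-valued Records compare (and now hash) equal.

```python
# Hedged sketch: partition Records as dictionary keys.
# Assumptions (not shown in this diff): Record(...) accepts keyword field values,
# and two Records with equal values compare equal. Paths are hypothetical.
from pyiceberg.typedef import Record

files_by_partition: dict[Record, list[str]] = {}
for path, partition in [
    ("data/n_legs=2/00000-0-aaa.parquet", Record(n_legs=2)),
    ("data/n_legs=4/00000-1-bbb.parquet", Record(n_legs=4)),
    ("data/n_legs=4/00000-2-ccc.parquet", Record(n_legs=4)),
]:
    files_by_partition.setdefault(partition, []).append(path)

assert len(files_by_partition) == 2  # both n_legs=4 files share one key
```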