Make the snapshot creation part of the Transaction #446

Merged: 2 commits, Feb 20, 2024
11 changes: 6 additions & 5 deletions pyiceberg/io/pyarrow.py
@@ -1714,7 +1714,7 @@ def fill_parquet_file_metadata(
data_file.split_offsets = split_offsets


-def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
+def write_file(table: Table, tasks: Iterator[WriteTask], file_schema: Optional[Schema] = None) -> Iterator[DataFile]:
task = next(tasks)

try:
@@ -1727,7 +1727,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)

file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = schema_to_pyarrow(table.schema())
+    file_schema = file_schema or table.schema()
+    arrow_file_schema = schema_to_pyarrow(file_schema)

fo = table.io.new_output(file_path)
row_group_size = PropertyUtil.property_as_int(
@@ -1736,7 +1737,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT,
)
with fo.create(overwrite=True) as fos:
-        with pq.ParquetWriter(fos, schema=file_schema, **parquet_writer_kwargs) as writer:
+        with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer:
writer.write_table(task.df, row_group_size=row_group_size)

data_file = DataFile(
@@ -1758,8 +1759,8 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
fill_parquet_file_metadata(
data_file=data_file,
parquet_metadata=writer.writer.metadata,
-        stats_columns=compute_statistics_plan(table.schema(), table.properties),
-        parquet_column_mapping=parquet_path_to_id_mapping(table.schema()),
+        stats_columns=compute_statistics_plan(file_schema, table.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(file_schema),
)
return iter([data_file])
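
The practical effect of this change is that write_file no longer hard-codes the committed table schema: file metrics and the Parquet column mapping follow whichever schema the file was written with. A minimal sketch of the two call styles through _dataframe_to_data_files, assuming a loaded Table tbl, an open Transaction txn, and a pyarrow.Table df (all hypothetical names):

    from pyiceberg.table import _dataframe_to_data_files

    # Default: file_schema=None, so write_file falls back to tbl.schema().
    data_files = list(_dataframe_to_data_files(tbl, df=df))

    # Explicit: metrics and the column mapping are computed from the schema
    # staged in the open transaction (see Transaction.schema() below).
    staged_files = list(_dataframe_to_data_files(tbl, df=df, file_schema=txn.schema()))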

251 changes: 165 additions & 86 deletions pyiceberg/table/__init__.py
@@ -329,6 +329,14 @@ def update_schema(self) -> UpdateSchema:
"""
return UpdateSchema(self._table, self)

+    def update_snapshot(self) -> UpdateSnapshot:
+        """Create a new UpdateSnapshot to produce a new snapshot for the table.
+
+        Returns:
+            A new UpdateSnapshot
+        """
+        return UpdateSnapshot(self._table, self)

def remove_properties(self, *removals: str) -> Transaction:
"""Remove properties.

@@ -351,6 +359,12 @@ def update_location(self, location: str) -> Transaction:
"""
raise NotImplementedError("Not yet implemented")

+    def schema(self) -> Schema:
+        try:
+            return next(update for update in self._updates if isinstance(update, AddSchemaUpdate)).schema_
+        except StopIteration:
+            return self._table.schema()
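
This helper is what lets a snapshot producer pick up a schema that is staged but not yet committed. A hedged sketch of the intended flow, assuming update_schema() records an AddSchemaUpdate on the transaction when it commits; the catalog, table name, and column are illustrative only:

    from pyiceberg.types import LongType

    tbl = catalog.load_table("default.events")
    with tbl.transaction() as txn:
        with txn.update_schema() as update:
            update.add_column("event_id", LongType())
        # schema() now resolves to the staged AddSchemaUpdate rather than
        # the committed table schema, so writes inside this transaction can
        # already target the pending schema.
        pending_schema = txn.schema()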

def commit_transaction(self) -> Table:
"""Commit the changes to the catalog.

@@ -965,8 +979,21 @@ def history(self) -> List[SnapshotLogEntry]:
return self.metadata.snapshot_log

def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
"""Create a new UpdateSchema to alter the columns of this table.

Returns:
A new UpdateSchema.
"""
return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)

+    def update_snapshot(self) -> UpdateSnapshot:
+        """Create a new UpdateSnapshot to produce a new snapshot for the table.
+
+        Returns:
+            A new UpdateSnapshot
+        """
+        return UpdateSnapshot(self)
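
For orientation, this is the shape of the new public API; the reworked Table.append below is essentially this sketch (tbl and df are assumed to exist):

    with tbl.update_snapshot().fast_append() as update_snapshot:
        for data_file in _dataframe_to_data_files(tbl, df=df):
            update_snapshot.append_data_file(data_file)
    # Exiting the context manager commits the new snapshot.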

def name_mapping(self) -> NameMapping:
"""Return the table's field-id NameMapping."""
if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
@@ -976,7 +1003,7 @@ def name_mapping(self) -> NameMapping:

def append(self, df: pa.Table) -> None:
"""
-        Append data to the table.
+        Shorthand API for appending a PyArrow table to the table.

Args:
df: The Arrow dataframe that will be appended to the table
@@ -992,19 +1019,16 @@ def append(self, df: pa.Table) -> None:
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-        merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
-
-        # skip writing data files if the dataframe is empty
-        if df.shape[0] > 0:
-            data_files = _dataframe_to_data_files(self, df=df)
-            for data_file in data_files:
-                merge.append_data_file(data_file)
-
-        merge.commit()
+        with self.update_snapshot().fast_append() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(self, df=df)
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)

def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> None:
"""
-        Overwrite all the data in the table.
+        Shorthand for overwriting the table with a PyArrow table.

Args:
df: The Arrow dataframe that will be used to overwrite the table
@@ -1025,18 +1049,12 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
if len(self.spec().fields) > 0:
raise ValueError("Cannot write to partitioned tables")

-        merge = _MergingSnapshotProducer(
-            operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
-            table=self,
-        )
-
-        # skip writing data files if the dataframe is empty
-        if df.shape[0] > 0:
-            data_files = _dataframe_to_data_files(self, df=df)
-            for data_file in data_files:
-                merge.append_data_file(data_file)
-
-        merge.commit()
+        with self.update_snapshot().overwrite() as update_snapshot:
+            # skip writing data files if the dataframe is empty
+            if df.shape[0] > 0:
+                data_files = _dataframe_to_data_files(self, df=df)
+                for data_file in data_files:
+                    update_snapshot.append_data_file(data_file)

def refs(self) -> Dict[str, SnapshotRef]:
"""Return the snapshot references in the table."""
@@ -2331,7 +2349,12 @@ def _generate_manifest_list_path(location: str, snapshot_id: int, attempt: int,
return f'{location}/metadata/snap-{snapshot_id}-{attempt}-{commit_uuid}.avro'


-def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
+def _dataframe_to_data_files(table: Table, df: pa.Table, file_schema: Optional[Schema] = None) -> Iterable[DataFile]:
+    """Convert a PyArrow table into a DataFile.
+
+    Returns:
+        An iterable that supplies datafiles that represent the table.
+    """
from pyiceberg.io.pyarrow import write_file

if len(table.spec().fields) > 0:
@@ -2342,7 +2365,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:

# This is an iter, so we don't have to materialize everything every time
# This will be more relevant when we start doing partitioned writes
-    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]))
+    yield from write_file(table, iter([WriteTask(write_uuid, next(counter), df)]), file_schema=file_schema)


class _MergingSnapshotProducer:
@@ -2352,55 +2375,35 @@ class _MergingSnapshotProducer:
_parent_snapshot_id: Optional[int]
_added_data_files: List[DataFile]
_commit_uuid: uuid.UUID
+    _transaction: Optional[Transaction]

-    def __init__(self, operation: Operation, table: Table) -> None:
+    def __init__(self, operation: Operation, table: Table, transaction: Optional[Transaction] = None) -> None:
self._operation = operation
self._table = table
self._snapshot_id = table.new_snapshot_id()
# Since we only support the main branch for now
self._parent_snapshot_id = snapshot.snapshot_id if (snapshot := self._table.current_snapshot()) else None
self._added_data_files = []
self._commit_uuid = uuid.uuid4()
+        self._transaction = transaction

+    def __enter__(self) -> _MergingSnapshotProducer:
+        """Start a transaction to update the table."""
+        return self
+
+    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+        """Close and commit the transaction."""
+        self.commit()

def append_data_file(self, data_file: DataFile) -> _MergingSnapshotProducer:
self._added_data_files.append(data_file)
return self

-    def _deleted_entries(self) -> List[ManifestEntry]:
-        """To determine if we need to record any deleted entries.
-
-        With partial overwrites we have to use the predicate to evaluate
-        which entries are affected.
-        """
-        if self._operation == Operation.OVERWRITE:
-            if self._parent_snapshot_id is not None:
-                previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
-                if previous_snapshot is None:
-                    # This should never happen since you cannot overwrite an empty table
-                    raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
-
-                executor = ExecutorFactory.get_or_create()
-
-                def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
-                    return [
-                        ManifestEntry(
-                            status=ManifestEntryStatus.DELETED,
-                            snapshot_id=entry.snapshot_id,
-                            data_sequence_number=entry.data_sequence_number,
-                            file_sequence_number=entry.file_sequence_number,
-                            data_file=entry.data_file,
-                        )
-                        for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
-                        if entry.data_file.content == DataFileContent.DATA
-                    ]
-
-                list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
-                return list(chain(*list_of_entries))
-            return []
-        elif self._operation == Operation.APPEND:
-            return []
-        else:
-            raise ValueError(f"Not implemented for: {self._operation}")
+    @abstractmethod
+    def _deleted_entries(self) -> List[ManifestEntry]: ...
+
+    @abstractmethod
+    def _existing_manifests(self) -> List[ManifestFile]: ...

def _manifests(self) -> List[ManifestFile]:
def _write_added_manifest() -> List[ManifestFile]:
@@ -2430,7 +2433,7 @@ def _write_added_manifest() -> List[ManifestFile]:
def _write_delete_manifest() -> List[ManifestFile]:
# Check if we need to mark the files as deleted
deleted_entries = self._deleted_entries()
-            if deleted_entries:
+            if len(deleted_entries) > 0:
output_file_location = _new_manifest_path(location=self._table.location(), num=1, commit_uuid=self._commit_uuid)
with write_manifest(
format_version=self._table.format_version,
@@ -2445,32 +2448,11 @@ def _write_delete_manifest() -> List[ManifestFile]:
else:
return []

-        def _fetch_existing_manifests() -> List[ManifestFile]:
-            existing_manifests = []
-
-            # Add existing manifests
-            if self._operation == Operation.APPEND and self._parent_snapshot_id is not None:
-                # In case we want to append, just add the existing manifests
-                previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
-
-                if previous_snapshot is None:
-                    raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
-
-                for manifest in previous_snapshot.manifests(io=self._table.io):
-                    if (
-                        manifest.has_added_files()
-                        or manifest.has_existing_files()
-                        or manifest.added_snapshot_id == self._snapshot_id
-                    ):
-                        existing_manifests.append(manifest)
-
-            return existing_manifests

executor = ExecutorFactory.get_or_create()

added_manifests = executor.submit(_write_added_manifest)
delete_manifests = executor.submit(_write_delete_manifest)
-        existing_manifests = executor.submit(_fetch_existing_manifests)
+        existing_manifests = executor.submit(self._existing_manifests)

return added_manifests.result() + delete_manifests.result() + existing_manifests.result()

@@ -2515,10 +2497,107 @@ def commit(self) -> Snapshot:
schema_id=self._table.schema().schema_id,
)

-        with self._table.transaction() as tx:
-            tx.add_snapshot(snapshot=snapshot)
-            tx.set_ref_snapshot(
-                snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
-            )
+        if self._transaction is not None:
+            self._transaction.add_snapshot(snapshot=snapshot)
+            self._transaction.set_ref_snapshot(
+                snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
+            )
+        else:
+            with self._table.transaction() as tx:
+                tx.add_snapshot(snapshot=snapshot)
+                tx.set_ref_snapshot(
+                    snapshot_id=self._snapshot_id, parent_snapshot_id=self._parent_snapshot_id, ref_name="main", type="branch"
+                )

return snapshot
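
This branch is what makes snapshot creation part of the Transaction, per the PR title: with a transaction attached, the snapshot is staged onto it instead of being committed through a fresh one. A sketch of the two paths, with tbl and data_file assumed to exist:

    # Standalone: commit() opens its own short-lived transaction, so the
    # snapshot reaches the catalog as soon as the context manager exits.
    with tbl.update_snapshot().fast_append() as update:
        update.append_data_file(data_file)

    # Inside an open transaction: the snapshot is staged via add_snapshot /
    # set_ref_snapshot and only becomes visible on commit_transaction(),
    # i.e. when the outer `with` block exits.
    with tbl.transaction() as txn:
        with txn.update_snapshot().fast_append() as update:
            update.append_data_file(data_file)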


+class FastAppendFiles(_MergingSnapshotProducer):
+    def _existing_manifests(self) -> List[ManifestFile]:
+        """To determine if there are any existing manifest files.
+
+        A fast append will add another ManifestFile to the ManifestList.
+        All the existing manifest files are considered existing.
+        """
+        existing_manifests = []
+
+        if self._parent_snapshot_id is not None:
+            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+
+            if previous_snapshot is None:
+                raise ValueError(f"Snapshot could not be found: {self._parent_snapshot_id}")
+
+            for manifest in previous_snapshot.manifests(io=self._table.io):
+                if manifest.has_added_files() or manifest.has_existing_files() or manifest.added_snapshot_id == self._snapshot_id:
+                    existing_manifests.append(manifest)
+
+        return existing_manifests
+
+    def _deleted_entries(self) -> List[ManifestEntry]:
+        """To determine if we need to record any deleted manifest entries.
+
+        In case of an append, nothing is deleted.
+        """
+        return []


+class OverwriteFiles(_MergingSnapshotProducer):
+    def _existing_manifests(self) -> List[ManifestFile]:
+        """To determine if there are any existing manifest files.
+
+        In the case of a full overwrite, all the previous manifests are
+        considered deleted.
+        """
+        return []
+
+    def _deleted_entries(self) -> List[ManifestEntry]:
+        """To determine if we need to record any deleted entries.
+
+        With a full overwrite all the entries are considered deleted.
+        With partial overwrites we have to use the predicate to evaluate
+        which entries are affected.
+        """
+        if self._parent_snapshot_id is not None:
+            previous_snapshot = self._table.snapshot_by_id(self._parent_snapshot_id)
+            if previous_snapshot is None:
+                # This should never happen since you cannot overwrite an empty table
+                raise ValueError(f"Could not find the previous snapshot: {self._parent_snapshot_id}")
+
+            executor = ExecutorFactory.get_or_create()
+
+            def _get_entries(manifest: ManifestFile) -> List[ManifestEntry]:
+                return [
+                    ManifestEntry(
+                        status=ManifestEntryStatus.DELETED,
+                        snapshot_id=entry.snapshot_id,
+                        data_sequence_number=entry.data_sequence_number,
+                        file_sequence_number=entry.file_sequence_number,
+                        data_file=entry.data_file,
+                    )
+                    for entry in manifest.fetch_manifest_entry(self._table.io, discard_deleted=True)
+                    if entry.data_file.content == DataFileContent.DATA
+                ]
+
+            list_of_entries = executor.map(_get_entries, previous_snapshot.manifests(self._table.io))
+            return list(chain(*list_of_entries))
+        else:
+            return []


+class UpdateSnapshot:
+    _table: Table
+    _transaction: Optional[Transaction]
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = None) -> None:
+        self._table = table
+        self._transaction = transaction
+
+    def fast_append(self) -> FastAppendFiles:
+        return FastAppendFiles(table=self._table, operation=Operation.APPEND, transaction=self._transaction)
+
+    def overwrite(self) -> OverwriteFiles:
+        return OverwriteFiles(
+            table=self._table,
+            operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND,
+            transaction=self._transaction,
+        )
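
Taken together, UpdateSnapshot is a small factory over the two _MergingSnapshotProducer subclasses: fast_append keeps every previous manifest and only writes a manifest for the added files, while overwrite reports no existing manifests and records all previous entries as deleted. A closing sketch (tbl and df assumed to exist):

    with tbl.update_snapshot().overwrite() as update:
        for data_file in _dataframe_to_data_files(tbl, df=df):
            update.append_data_file(data_file)
    # On an empty table this degrades to Operation.APPEND, matching the
    # behaviour of Table.overwrite().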