Merged · changes from all 24 commits
203 changes: 193 additions & 10 deletions pyiceberg/table/__init__.py
@@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations

import datetime
import itertools
import uuid
from abc import ABC, abstractmethod
from copy import copy
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from functools import cached_property, singledispatch
from itertools import chain
from typing import (
    TYPE_CHECKING,
@@ -41,6 +42,7 @@

from pydantic import Field, SerializeAsAny
from sortedcontainers import SortedList
from typing_extensions import Annotated

from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import (
@@ -69,8 +71,13 @@
    promote,
    visit,
)
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata
from pyiceberg.table.refs import SnapshotRef
from pyiceberg.table.metadata import (
    INITIAL_SEQUENCE_NUMBER,
    SUPPORTED_TABLE_FORMAT_VERSION,
    TableMetadata,
    TableMetadataUtil,
)
from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
from pyiceberg.table.sorting import SortOrder
from pyiceberg.typedef import (
@@ -90,6 +97,7 @@
    StructType,
)
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.datetime import datetime_to_millis

if TYPE_CHECKING:
    import pandas as pd
@@ -320,9 +328,9 @@ class SetSnapshotRefUpdate(TableUpdate):
    ref_name: str = Field(alias="ref-name")
    type: Literal["tag", "branch"]
    snapshot_id: int = Field(alias="snapshot-id")
    max_age_ref_ms: int = Field(alias="max-ref-age-ms")
    max_snapshot_age_ms: int = Field(alias="max-snapshot-age-ms")
    min_snapshots_to_keep: int = Field(alias="min-snapshots-to-keep")
    max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)]
    max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)]
    min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]


class RemoveSnapshotsUpdate(TableUpdate):
@@ -350,6 +358,184 @@ class RemovePropertiesUpdate(TableUpdate):
    removals: List[str]


class _TableMetadataUpdateContext:
    _updates: List[TableUpdate]

    def __init__(self) -> None:
        self._updates = []

    def add_update(self, update: TableUpdate) -> None:
        self._updates.append(update)

    def is_added_snapshot(self, snapshot_id: int) -> bool:
        return any(
            update.snapshot.snapshot_id == snapshot_id
            for update in self._updates
            if update.action == TableUpdateAction.add_snapshot
        )

    def is_added_schema(self, schema_id: int) -> bool:
        return any(
            update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
        )


@singledispatch
def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    """Apply a table update to the table metadata.

    Args:
        update: The update to be applied.
        base_metadata: The base metadata to be updated.
        context: Contains previous updates and other change tracking information in the current transaction.

    Returns:
        The updated metadata.

    """
    raise NotImplementedError(f"Unsupported table update: {update}")


@_apply_table_update.register(UpgradeFormatVersionUpdate)
def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION:
        raise ValueError(f"Unsupported table format version: {update.format_version}")
    elif update.format_version < base_metadata.format_version:
        raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}")
    elif update.format_version == base_metadata.format_version:
        return base_metadata

    updated_metadata_data = copy(base_metadata.model_dump())
    updated_metadata_data["format-version"] = update.format_version
Comment on lines +408 to +409

@Fokko (Contributor) · Nov 20, 2023

While this is a very safe way of doing the copy, it is also rather expensive, since we convert everything to a Python dict and then create a new object again. Pydantic has the model_copy method that seems to do what we're looking for:

Suggested change:
-    updated_metadata_data = copy(base_metadata.model_dump())
-    updated_metadata_data["format-version"] = update.format_version
+    updated_metadata_data = base_metadata.model_copy(**{"format-version": update.format_version})

We could construct a dict where we add all the changes (for the more complicated updates below), and then call model_copy for each update.

This will make a shallow copy by default (which I think is okay, since the model is immutable).

@HonahX (Contributor, Author)

Thanks for the explanation! I am a little worried about the immutability of the table metadata. I think Pydantic's frozen config does not prevent updates to list, dict, etc. If we make a shallow copy of the list fields in the metadata and some code later mistakenly alters a list (e.g. appends something) in the updated metadata, the change will propagate to the base_metadata too, and that may be hard to detect. Do you think this might be a problem?

@Fokko (Contributor)

In that case, we can be cautious and set deep=True. I would love to see some tests that validate the behavior. Those should be easy to add.

@HonahX (Contributor, Author)

Thanks for the suggestion! I've implemented the suggested change on my end, but I'm still in the process of building the tests for shallow vs deep copy. Given that the current PR already contains a lot of changes, do you think it might be a good idea to make the model_copy switch in a separate, follow-up PR?

@Fokko (Contributor)

Sounds good @HonahX. I've created a new issue here: #179
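
The shallow-versus-deep behavior discussed in this thread is easy to demonstrate. A minimal sketch, assuming Pydantic v2; Meta is an illustrative stand-in for TableMetadata and is not part of this PR:

from typing import List

from pydantic import BaseModel, ConfigDict


class Meta(BaseModel):
    # frozen=True blocks attribute assignment, but not mutation of nested containers
    model_config = ConfigDict(frozen=True)

    snapshots: List[int]


base = Meta(snapshots=[1])

shallow = base.model_copy()        # default is a shallow copy: the list object is shared
shallow.snapshots.append(2)
assert base.snapshots == [1, 2]    # the mutation leaked into the "immutable" original

deep = base.model_copy(deep=True)  # deep=True copies nested containers as well
deep.snapshots.append(3)
assert base.snapshots == [1, 2]    # the original is unaffected this time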


    context.add_update(update)
    return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(AddSchemaUpdate)
def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    if update.last_column_id < base_metadata.last_column_id:
        raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}")

    updated_metadata_data = copy(base_metadata.model_dump())
    updated_metadata_data["last-column-id"] = update.last_column_id
    updated_metadata_data["schemas"].append(update.schema_.model_dump())

    context.add_update(update)
    return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(SetCurrentSchemaUpdate)
def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    new_schema_id = update.schema_id
    if new_schema_id == -1:
        # The last added schema should be in base_metadata.schemas at this point
        new_schema_id = max(schema.schema_id for schema in base_metadata.schemas)
        if not context.is_added_schema(new_schema_id):
            raise ValueError("Cannot set current schema to last added schema when no schema has been added")

    if new_schema_id == base_metadata.current_schema_id:
        return base_metadata

    schema = base_metadata.schema_by_id(new_schema_id)
    if schema is None:
        raise ValueError(f"Schema with id {new_schema_id} does not exist")

    updated_metadata_data = copy(base_metadata.model_dump())
    updated_metadata_data["current-schema-id"] = new_schema_id

    context.add_update(update)
    return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(AddSnapshotUpdate)
def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    if len(base_metadata.schemas) == 0:
        raise ValueError("Attempting to add a snapshot before a schema is added")
    elif len(base_metadata.partition_specs) == 0:
        raise ValueError("Attempting to add a snapshot before a partition spec is added")
    elif len(base_metadata.sort_orders) == 0:
        raise ValueError("Attempting to add a snapshot before a sort order is added")
    elif base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None:
        raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists")
    elif (
        base_metadata.format_version == 2
        and update.snapshot.sequence_number is not None
        and update.snapshot.sequence_number <= base_metadata.last_sequence_number
        and update.snapshot.parent_snapshot_id is not None
    ):
        raise ValueError(
            f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} "
            f"older than last sequence number {base_metadata.last_sequence_number}"
        )

    updated_metadata_data = copy(base_metadata.model_dump())
    updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
    updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number
    updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
    context.add_update(update)
    return TableMetadataUtil.parse_obj(updated_metadata_data)


@_apply_table_update.register(SetSnapshotRefUpdate)
def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
    snapshot_ref = SnapshotRef(
        snapshot_id=update.snapshot_id,
        snapshot_ref_type=update.type,
        min_snapshots_to_keep=update.min_snapshots_to_keep,
        max_snapshot_age_ms=update.max_snapshot_age_ms,
        max_ref_age_ms=update.max_ref_age_ms,
    )

    existing_ref = base_metadata.refs.get(update.ref_name)
    if existing_ref is not None and existing_ref == snapshot_ref:
        return base_metadata

    snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
    if snapshot is None:
        raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")

    updated_metadata_data = copy(base_metadata.model_dump())
    update_last_updated_ms = True
    if context.is_added_snapshot(snapshot_ref.snapshot_id):
        updated_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
        update_last_updated_ms = False

    if update.ref_name == MAIN_BRANCH:
        updated_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
        if update_last_updated_ms:
            updated_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
        updated_metadata_data["snapshot-log"].append(
            SnapshotLogEntry(
                snapshot_id=snapshot_ref.snapshot_id,
                timestamp_ms=updated_metadata_data["last-updated-ms"],
            ).model_dump()
        )

    updated_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
    context.add_update(update)
    return TableMetadataUtil.parse_obj(updated_metadata_data)


def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata:
"""Update the table metadata with the given updates in one transaction.

Args:
base_metadata: The base metadata to be updated.
updates: The updates in one transaction.

Returns:
The metadata with the updates applied.
"""
context = _TableMetadataUpdateContext()
new_metadata = base_metadata

for update in updates:
new_metadata = _apply_table_update(update, new_metadata, context)

return new_metadata


class TableRequirement(IcebergBaseModel):
    type: str

@@ -552,10 +738,7 @@ def current_snapshot(self) -> Optional[Snapshot]:

    def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]:
        """Get the snapshot of this table with the given id, or None if there is no matching snapshot."""
        try:
            return next(snapshot for snapshot in self.metadata.snapshots if snapshot.snapshot_id == snapshot_id)
        except StopIteration:
            return None
        return self.metadata.snapshot_by_id(snapshot_id)

    def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
        """Return the snapshot referenced by the given name or null if no such reference exists."""
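
For orientation, a hedged usage sketch of the new update_table_metadata entry point, not a test from this PR; the input dict mirrors the EXAMPLE_TABLE_METADATA_V1 fixture added to tests/conftest.py below, trimmed to the essential fields, and constructing the update by field name assumes pyiceberg's populate-by-name models:

from pyiceberg.table import UpgradeFormatVersionUpdate, update_table_metadata
from pyiceberg.table.metadata import TableMetadataUtil

base = TableMetadataUtil.parse_obj({
    "format-version": 1,
    "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
    "location": "s3://bucket/test/location",
    "last-updated-ms": 1602638573874,
    "last-column-id": 3,
    "schema": {
        "type": "struct",
        "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}],
    },
    "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}],
})

# Apply a single-update "transaction" that upgrades the table from v1 to v2.
new = update_table_metadata(base, (UpgradeFormatVersionUpdate(format_version=2),))

assert base.format_version == 1  # the base metadata object is left untouched
assert new.format_version == 2   # updates are applied to a freshly parsed object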
10 changes: 10 additions & 0 deletions pyiceberg/table/metadata.py
@@ -69,6 +69,8 @@
INITIAL_SPEC_ID = 0
DEFAULT_SCHEMA_ID = 0

SUPPORTED_TABLE_FORMAT_VERSION = 2


def cleanup_snapshot_id(data: Dict[str, Any]) -> Dict[str, Any]:
"""Run before validation."""
@@ -216,6 +218,14 @@ class TableMetadataCommonFields(IcebergBaseModel):
    There is always a main branch reference pointing to the
    current-snapshot-id even if the refs map is null."""

    def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]:
        """Get the snapshot by snapshot_id."""
        return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None)

    def schema_by_id(self, schema_id: int) -> Optional[Schema]:
        """Get the schema by schema_id."""
        return next((schema for schema in self.schemas if schema.schema_id == schema_id), None)


class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel):
"""Represents version 1 of the Table Metadata.
22 changes: 18 additions & 4 deletions pyiceberg/table/refs.py
@@ -17,8 +17,10 @@
from enum import Enum
from typing import Optional

from pydantic import Field
from pydantic import Field, model_validator
from typing_extensions import Annotated

from pyiceberg.exceptions import ValidationError
from pyiceberg.typedef import IcebergBaseModel

MAIN_BRANCH = "main"
@@ -36,6 +38,18 @@ def __repr__(self) -> str:
class SnapshotRef(IcebergBaseModel):
    snapshot_id: int = Field(alias="snapshot-id")
    snapshot_ref_type: SnapshotRefType = Field(alias="type")
    min_snapshots_to_keep: Optional[int] = Field(alias="min-snapshots-to-keep", default=None)
    max_snapshot_age_ms: Optional[int] = Field(alias="max-snapshot-age-ms", default=None)
    max_ref_age_ms: Optional[int] = Field(alias="max-ref-age-ms", default=None)
    min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None, gt=0)]
    max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None, gt=0)]
    max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None, gt=0)]

    @model_validator(mode='after')
    def check_min_snapshots_to_keep(self) -> 'SnapshotRef':
        if self.min_snapshots_to_keep is not None and self.snapshot_ref_type == SnapshotRefType.TAG:
            raise ValidationError("Tags do not support setting minSnapshotsToKeep")
        return self

    @model_validator(mode='after')
    def check_max_snapshot_age_ms(self) -> 'SnapshotRef':
        if self.max_snapshot_age_ms is not None and self.snapshot_ref_type == SnapshotRefType.TAG:
            raise ValidationError("Tags do not support setting maxSnapshotAgeMs")
        return self
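
A quick sketch of the new SnapshotRef validation; keyword construction by field name again assumes populate-by-name models, and it assumes Pydantic surfaces pyiceberg's ValidationError from the model validators unwrapped:

from pyiceberg.exceptions import ValidationError
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType

# Branches may carry retention settings.
branch = SnapshotRef(snapshot_id=1, snapshot_ref_type=SnapshotRefType.BRANCH, min_snapshots_to_keep=5)

# Tags may not: the validators above reject this at construction time.
try:
    SnapshotRef(snapshot_id=1, snapshot_ref_type=SnapshotRefType.TAG, min_snapshots_to_keep=5)
except ValidationError as err:
    print(err)  # Tags do not support setting minSnapshotsToKeep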
42 changes: 40 additions & 2 deletions tests/conftest.py
@@ -73,7 +73,7 @@
from pyiceberg.schema import Accessor, Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
    BinaryType,
@@ -354,6 +354,32 @@ def all_avro_types() -> Dict[str, Any]:
    }


EXAMPLE_TABLE_METADATA_V1 = {
"format-version": 1,
"table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
"location": "s3://bucket/test/location",
"last-updated-ms": 1602638573874,
"last-column-id": 3,
"schema": {
"type": "struct",
"fields": [
{"id": 1, "name": "x", "required": True, "type": "long"},
{"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
{"id": 3, "name": "z", "required": True, "type": "long"},
],
},
"partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}],
"properties": {},
"current-snapshot-id": -1,
"snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}],
}


@pytest.fixture(scope="session")
def example_table_metadata_v1() -> Dict[str, Any]:
    return EXAMPLE_TABLE_METADATA_V1


EXAMPLE_TABLE_METADATA_WITH_SNAPSHOT_V1 = {
"format-version": 1,
"table-uuid": "b55d9dda-6561-423a-8bfc-787980ce421f",
@@ -1780,7 +1806,19 @@ def example_task(data_file: str) -> FileScanTask:


@pytest.fixture
def table(example_table_metadata_v2: Dict[str, Any]) -> Table:
def table_v1(example_table_metadata_v1: Dict[str, Any]) -> Table:
    table_metadata = TableMetadataV1(**example_table_metadata_v1)
    return Table(
        identifier=("database", "table"),
        metadata=table_metadata,
        metadata_location=f"{table_metadata.location}/uuid.metadata.json",
        io=load_file_io(),
        catalog=NoopCatalog("NoopCatalog"),
    )


@pytest.fixture
def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
    table_metadata = TableMetadataV2(**example_table_metadata_v2)
    return Table(
        identifier=("database", "table"),