Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions python/pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def __init__(
requirements: Optional[Tuple[TableRequirement, ...]] = None,
):
self._table = table
self._table_metadata = table.metadata
self._updates = actions or ()
self._requirements = requirements or ()

Expand All @@ -136,7 +135,7 @@ def _append_updates(self, *new_updates: TableUpdate) -> Transaction:
ValueError: When the type of update is not unique.

Returns:
A new AlterTable object with the new updates appended.
Transaction object with the new updates appended.
"""
for new_update in new_updates:
type_new_update = type(new_update)
Expand All @@ -145,6 +144,25 @@ def _append_updates(self, *new_updates: TableUpdate) -> Transaction:
self._updates = self._updates + new_updates
return self

def _append_requirements(self, *new_requirements: TableRequirement) -> Transaction:
"""Appends requirements to the set of staged requirements.

Args:
*new_requirements: Any new requirements.

Raises:
ValueError: When the type of requirement is not unique.

Returns:
Transaction object with the new requirements appended.
"""
for requirement in new_requirements:
type_new_requirement = type(requirement)
if any(type(update) == type_new_requirement for update in self._updates):
raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}")
self._requirements = self._requirements + new_requirements
return self

def set_table_version(self, format_version: Literal[1, 2]) -> Transaction:
"""Sets the table to a certain version.

Expand Down Expand Up @@ -205,9 +223,6 @@ def commit_transaction(self) -> Table:
Returns:
The table with the updates applied.
"""
if self._table.metadata != self._table_metadata:
raise RuntimeError("Table metadata refresh is required")

# Strip the catalog name
if len(self._updates) > 0:
response = self._table.catalog._commit_table( # pylint: disable=W0212
Expand Down Expand Up @@ -954,29 +969,22 @@ def allow_incompatible_changes(self) -> UpdateSchema:

def commit(self) -> None:
"""Apply the pending changes and commit."""
if self._transaction is not None:
if self._table.metadata != self._transaction._table_metadata: # pylint: disable=W0212
raise RuntimeError("Table metadata refresh is required")
new_schema = self._apply()
self._transaction._append_updates( # pylint: disable=W0212
AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id)
)
return

# Strip the catalog name
new_schema = self._apply()
table_update_response = self._table.catalog._commit_table( # pylint: disable=W0212
CommitTableRequest(
identifier=self._table.identifier[1:],
updates=[
AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id),
SetCurrentSchemaUpdate(schema_id=-1),
],
)
)
updates = [
AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id),
SetCurrentSchemaUpdate(schema_id=-1),
]
requirements = [AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)]

self._table.metadata = table_update_response.metadata
self._table.metadata_location = table_update_response.metadata_location
if self._transaction is not None:
self._transaction._append_updates(*updates) # pylint: disable=W0212
self._transaction._append_requirements(*requirements) # pylint: disable=W0212
else:
table_update_response = self._table.catalog._commit_table( # pylint: disable=W0212
CommitTableRequest(identifier=self._table.identifier[1:], updates=updates, requirements=requirements)
)
self._table.metadata = table_update_response.metadata
self._table.metadata_location = table_update_response.metadata_location

def _apply(self) -> Schema:
"""Apply the pending changes to the original schema and returns the result.
Expand Down
44 changes: 38 additions & 6 deletions python/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pyarrow.fs import S3FileSystem

from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.exceptions import CommitFailedException, NoSuchTableError
from pyiceberg.expressions import (
And,
EqualTo,
Expand All @@ -40,8 +40,10 @@
from pyiceberg.table import Table
from pyiceberg.types import (
BooleanType,
DoubleType,
FixedType,
IntegerType,
LongType,
NestedField,
StringType,
TimestampType,
Expand Down Expand Up @@ -395,18 +397,48 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None:
NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False),
)

t = catalog.create_table(identifier="default.test_schema_evolution", schema=schema)
tbl = catalog.create_table(identifier="default.test_schema_evolution", schema=schema)

assert t.schema() == schema
assert tbl.schema() == schema

with t.transaction() as tx:
with tbl.transaction() as tx:
tx.update_schema().add_column("col_string", StringType()).commit()

t = catalog.load_table("default.test_schema_evolution")
assert tbl.schema() == Schema(
NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False),
NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False),
NestedField(field_id=3, name="col_string", field_type=StringType(), required=False),
schema_id=1,
)

tbl.update_schema().add_column("col_integer", IntegerType()).commit()

assert t.schema() == Schema(
assert tbl.schema() == Schema(
NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False),
NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False),
NestedField(field_id=3, name="col_string", field_type=StringType(), required=False),
NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False),
schema_id=1,
)

with pytest.raises(CommitFailedException) as exc_info:
with tbl.transaction() as tx:
# Start a new update
schema_update = tx.update_schema()

# Do a concurrent update
tbl.update_schema().add_column("col_long", LongType()).commit()

# stage another update in the transaction
schema_update.add_column("col_double", DoubleType()).commit()

assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value)

assert tbl.schema() == Schema(
NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False),
NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False),
NestedField(field_id=3, name="col_string", field_type=StringType(), required=False),
NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False),
NestedField(field_id=5, name="col_long", field_type=LongType(), required=False),
schema_id=1,
)