def _append_requirements(self, *new_requirements: TableRequirement) -> Transaction:
    """Append requirements to the set of staged requirements.

    Every new requirement is validated before anything is staged, so a
    rejected call leaves the staged requirements untouched.

    Args:
        *new_requirements: Any new requirements.

    Raises:
        ValueError: When a requirement of the same type is already staged.

    Returns:
        Transaction object with the new requirements appended.
    """
    for new_requirement in new_requirements:
        type_new_requirement = type(new_requirement)
        # Compare against the staged *requirements*. Looking at the staged
        # updates instead would never detect a duplicate requirement.
        if any(type(staged) is type_new_requirement for staged in self._requirements):
            raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}")
    self._requirements = self._requirements + new_requirements
    return self
def commit(self) -> None:
    """Apply the pending changes and commit.

    The evolved schema is turned into an ``AddSchemaUpdate`` followed by a
    ``SetCurrentSchemaUpdate(schema_id=-1)``, guarded by an
    ``AssertCurrentSchemaId`` requirement built from the schema id this
    update started from.

    When the update belongs to a transaction, the updates and the
    requirement are only staged on that transaction. Otherwise they are
    sent to the catalog right away and the table's metadata and metadata
    location are replaced with the values from the response.
    """
    evolved = self._apply()
    add_schema = AddSchemaUpdate(schema=evolved, last_column_id=evolved.highest_field_id)
    # NOTE(review): -1 presumably selects the schema added in this same commit - confirm against the REST spec.
    use_added_schema = SetCurrentSchemaUpdate(schema_id=-1)
    schema_guard = AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)

    transaction = self._transaction
    if transaction is not None:
        # Stage only; nothing is sent to the catalog from here.
        transaction._append_updates(add_schema, use_added_schema)  # pylint: disable=W0212
        transaction._append_requirements(schema_guard)  # pylint: disable=W0212
        return

    commit_table = self._table.catalog._commit_table  # pylint: disable=W0212
    response = commit_table(
        CommitTableRequest(
            identifier=self._table.identifier[1:],
            updates=[add_schema, use_added_schema],
            requirements=[schema_guard],
        )
    )
    self._table.metadata = response.metadata
    self._table.metadata_location = response.metadata_location
"""Apply the pending changes to the original schema and returns the result. diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index fce3646230b2..acd694677463 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -25,7 +25,7 @@ from pyarrow.fs import S3FileSystem from pyiceberg.catalog import Catalog, load_catalog -from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.exceptions import CommitFailedException, NoSuchTableError from pyiceberg.expressions import ( And, EqualTo, @@ -40,8 +40,10 @@ from pyiceberg.table import Table from pyiceberg.types import ( BooleanType, + DoubleType, FixedType, IntegerType, + LongType, NestedField, StringType, TimestampType, @@ -395,18 +397,48 @@ def test_schema_evolution_via_transaction(catalog: Catalog) -> None: NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), ) - t = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) + tbl = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) - assert t.schema() == schema + assert tbl.schema() == schema - with t.transaction() as tx: + with tbl.transaction() as tx: tx.update_schema().add_column("col_string", StringType()).commit() - t = catalog.load_table("default.test_schema_evolution") + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + schema_id=1, + ) + + tbl.update_schema().add_column("col_integer", IntegerType()).commit() - assert t.schema() == Schema( + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, 
name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + schema_id=1, + ) + + with pytest.raises(CommitFailedException) as exc_info: + with tbl.transaction() as tx: + # Start a new update + schema_update = tx.update_schema() + + # Do a concurrent update + tbl.update_schema().add_column("col_long", LongType()).commit() + + # stage another update in the transaction + schema_update.add_column("col_double", DoubleType()).commit() + + assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value) + + assert tbl.schema() == Schema( NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="col_long", field_type=LongType(), required=False), schema_id=1, )