diff --git a/python/mkdocs/docs/api.md b/python/mkdocs/docs/api.md index d3b8fceee5c5..f0b2873c038e 100644 --- a/python/mkdocs/docs/api.md +++ b/python/mkdocs/docs/api.md @@ -146,6 +146,38 @@ catalog.create_table( ) ``` +### Update table schema + +Add new columns through the `Transaction` or `UpdateSchema` API: + +Use the Transaction API: + +```python +with table.transaction() as transaction: + transaction.update_schema().add_column("x", IntegerType(), "doc").commit() +``` + +Or, without a context manager: + +```python +transaction = table.transaction() +transaction.update_schema().add_column("x", IntegerType(), "doc").commit() +transaction.commit_transaction() +``` + +Or, use the UpdateSchema API directly: + +```python +with table.update_schema() as update: + update.add_column("x", IntegerType(), "doc") +``` + +Or, without a context manager: + +```python +table.update_schema().add_column("x", IntegerType(), "doc").commit() +``` + ### Update table properties Set and remove properties through the `Transaction` API: diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py index 74232d0b7b17..5064d07174a6 100644 --- a/python/pyiceberg/schema.py +++ b/python/pyiceberg/schema.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=W0511 +from __future__ import annotations + import itertools from abc import ABC, abstractmethod from dataclasses import dataclass @@ -145,7 +147,7 @@ def _lazy_id_to_name(self) -> Dict[int, str]: return index_name_by_id(self) @cached_property - def _lazy_id_to_accessor(self) -> Dict[int, "Accessor"]: + def _lazy_id_to_accessor(self) -> Dict[int, Accessor]: """Returns an index of field ID to accessor. This is calculated once when called for the first time. Subsequent calls to this method will use a cached index. @@ -201,7 +203,7 @@ def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> @property def highest_field_id(self) -> int: - return visit(self.as_struct(), _FindLastFieldId()) + return max(self._lazy_id_to_name.keys(), default=0) def find_column_name(self, column_id: int) -> Optional[str]: """Find a column name given a column ID. @@ -226,7 +228,7 @@ def column_names(self) -> List[str]: """ return list(self._lazy_id_to_name.values()) - def accessor_for_field(self, field_id: int) -> "Accessor": + def accessor_for_field(self, field_id: int) -> Accessor: """Find a schema position accessor given a field ID. Args: @@ -243,7 +245,7 @@ def accessor_for_field(self, field_id: int) -> "Accessor": return self._lazy_id_to_accessor[field_id] - def select(self, *names: str, case_sensitive: bool = True) -> "Schema": + def select(self, *names: str, case_sensitive: bool = True) -> Schema: """Return a new schema instance pruned to a subset of columns. Args: @@ -682,7 +684,7 @@ class Accessor: """An accessor for a specific position in a container that implements the StructProtocol.""" position: int - inner: Optional["Accessor"] = None + inner: Optional[Accessor] = None def __str__(self) -> str: """Returns the string representation of the Accessor class.""" @@ -766,7 +768,7 @@ def _(obj: MapType, visitor: SchemaVisitor[T]) -> T: visitor.before_map_value(obj.value_field) value_result = visit(obj.value_type, visitor) - visitor.after_list_element(obj.value_field) + visitor.after_map_value(obj.value_field) return visitor.map(obj, key_result, value_result) @@ -890,6 +892,22 @@ def __init__(self) -> None: self._field_names: List[str] = [] self._short_field_names: List[str] = [] + def before_map_key(self, key: NestedField) -> None: + self.before_field(key) + + def after_map_key(self, key: NestedField) -> None: + self.after_field(key) + + def before_map_value(self, value: NestedField) -> None: + if not isinstance(value.field_type, StructType): + self._short_field_names.append(value.name) + self._field_names.append(value.name) + + def after_map_value(self, value: NestedField) -> None: + if not isinstance(value.field_type, StructType): + self._short_field_names.pop() + self._field_names.pop() + def before_list_element(self, element: NestedField) -> None: """Short field names omit element when the element is a StructType.""" if not isinstance(element.field_type, StructType): @@ -1082,45 +1100,23 @@ def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict return visit(schema_or_type, _BuildPositionAccessors()) -class _FindLastFieldId(SchemaVisitor[int]): - """Traverses the schema to get the highest field-id.""" - - def schema(self, schema: Schema, struct_result: int) -> int: - return struct_result - - def struct(self, struct: StructType, field_results: List[int]) -> int: - return max(field_results) - - def field(self, field: NestedField, field_result: int) -> int: - return max(field.field_id, field_result) - - def list(self, list_type: ListType, element_result: int) -> int: - return element_result - - def map(self, map_type: MapType, key_result: int, value_result: int) -> int: - return max(key_result, value_result) - - def primitive(self, primitive: PrimitiveType) -> int: - return 0 - - -def assign_fresh_schema_ids(schema: Schema) -> Schema: +def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: Optional[Callable[[], int]] = None) -> Schema: """Traverses the schema, and sets new IDs.""" - return pre_order_visit(schema, _SetFreshIDs()) + return pre_order_visit(schema_or_type, _SetFreshIDs(next_id_func=next_id)) class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]): """Traverses the schema and assigns monotonically increasing ids.""" - counter: itertools.count # type: ignore reserved_ids: Dict[int, int] - def __init__(self, start: int = 1) -> None: - self.counter = itertools.count(start) + def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None: self.reserved_ids = {} + counter = itertools.count(1) + self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter) def _get_and_increment(self) -> int: - return next(self.counter) + return self.next_id_func() def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema: # First we keep the original identifier_field_ids here, we remap afterwards diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py index 52479c29ca20..3d4e5f7d2862 100644 --- a/python/pyiceberg/table/__init__.py +++ b/python/pyiceberg/table/__init__.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import itertools from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum @@ -58,7 +59,13 @@ ManifestFile, ) from pyiceberg.partitioning import PartitionSpec -from pyiceberg.schema import Schema +from pyiceberg.schema import ( + Schema, + SchemaVisitor, + assign_fresh_schema_ids, + index_by_name, + visit, +) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import SortOrder @@ -69,6 +76,14 @@ KeyDefaultDict, Properties, ) +from pyiceberg.types import ( + IcebergType, + ListType, + MapType, + NestedField, + PrimitiveType, + StructType, +) from pyiceberg.utils.concurrent import ExecutorFactory if TYPE_CHECKING: @@ -81,6 +96,7 @@ ALWAYS_TRUE = AlwaysTrue() +TABLE_ROOT_ID = -1 class Transaction: @@ -119,7 +135,7 @@ def _append_updates(self, *new_updates: TableUpdate) -> Transaction: ValueError: When the type of update is not unique. Returns: - A new AlterTable object with the new updates appended. + Transaction object with the new updates appended. """ for new_update in new_updates: type_new_update = type(new_update) @@ -128,6 +144,25 @@ def _append_updates(self, *new_updates: TableUpdate) -> Transaction: self._updates = self._updates + new_updates return self + def _append_requirements(self, *new_requirements: TableRequirement) -> Transaction: + """Appends requirements to the set of staged requirements. + + Args: + *new_requirements: Any new requirements. + + Raises: + ValueError: When the type of requirement is not unique. + + Returns: + Transaction object with the new requirements appended. + """ + for requirement in new_requirements: + type_new_requirement = type(requirement) + if any(type(update) == type_new_requirement for update in self._updates): + raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}") + self._requirements = self._requirements + new_requirements + return self + def set_table_version(self, format_version: Literal[1, 2]) -> Transaction: """Sets the table to a certain version. @@ -152,6 +187,14 @@ def set_properties(self, **updates: str) -> Transaction: """ return self._append_updates(SetPropertiesUpdate(updates=updates)) + def update_schema(self) -> UpdateSchema: + """Create a new UpdateSchema to alter the columns of this table. + + Returns: + A new UpdateSchema. + """ + return UpdateSchema(self._table.schema(), self._table, self) + def remove_properties(self, *removals: str) -> Transaction: """Removes properties. @@ -227,6 +270,8 @@ class UpgradeFormatVersionUpdate(TableUpdate): class AddSchemaUpdate(TableUpdate): action: TableUpdateAction = TableUpdateAction.add_schema schema_: Schema = Field(alias="schema") + # This field is required: https://github.com/apache/iceberg/pull/7445 + last_column_id: int = Field(alias="last-column-id") class SetCurrentSchemaUpdate(TableUpdate): @@ -307,13 +352,13 @@ class TableRequirement(IcebergBaseModel): class AssertCreate(TableRequirement): """The table must not already exist; used for create transactions.""" - type: Literal["assert-create"] + type: Literal["assert-create"] = Field(default="assert-create") class AssertTableUUID(TableRequirement): """The table UUID must match the requirement's `uuid`.""" - type: Literal["assert-table-uuid"] + type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") uuid: str @@ -323,7 +368,7 @@ class AssertRefSnapshotId(TableRequirement): if `snapshot-id` is `null` or missing, the ref must not already exist. """ - type: Literal["assert-ref-snapshot-id"] + type: Literal["assert-ref-snapshot-id"] = Field(default="assert-ref-snapshot-id") ref: str snapshot_id: int = Field(..., alias="snapshot-id") @@ -331,35 +376,35 @@ class AssertRefSnapshotId(TableRequirement): class AssertLastAssignedFieldId(TableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" - type: Literal["assert-last-assigned-field-id"] + type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") class AssertCurrentSchemaId(TableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" - type: Literal["assert-current-schema-id"] + type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" - type: Literal["assert-last-assigned-partition-id"] + type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") class AssertDefaultSpecId(TableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" - type: Literal["assert-default-spec-id"] + type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") class AssertDefaultSortOrderId(TableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" - type: Literal["assert-default-sort-order-id"] + type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") @@ -482,6 +527,9 @@ def history(self) -> List[SnapshotLogEntry]: """Get the snapshot history of this table.""" return self.metadata.snapshot_log + def update_schema(self) -> UpdateSchema: + return UpdateSchema(self.schema(), self) + def __eq__(self, other: Any) -> bool: """Returns the equality of two instances of the Table class.""" return ( @@ -839,3 +887,253 @@ def to_ray(self) -> ray.data.dataset.Dataset: import ray return ray.data.from_arrow(self.to_arrow()) + + +class UpdateSchema: + _table: Table + _schema: Schema + _last_column_id: itertools.count[int] + _identifier_field_names: List[str] + _adds: Dict[int, List[NestedField]] + _added_name_to_id: Dict[str, int] + _id_to_parent: Dict[int, str] + _allow_incompatible_changes: bool + _case_sensitive: bool + _transaction: Optional[Transaction] + + def __init__(self, schema: Schema, table: Table, transaction: Optional[Transaction] = None): + self._table = table + self._schema = schema + self._last_column_id = itertools.count(schema.highest_field_id + 1) + self._identifier_field_names = schema.column_names + self._adds = {} + self._added_name_to_id = {} + self._id_to_parent = {} + self._allow_incompatible_changes = False + self._case_sensitive = True + self._transaction = transaction + + def __exit__(self, _: Any, value: Any, traceback: Any) -> None: + """Closes and commits the change.""" + return self.commit() + + def __enter__(self) -> UpdateSchema: + """Update the table.""" + return self + + def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: + """Determines if the case of schema needs to be considered when comparing column names. + + Args: + case_sensitive: When false case is not considered in column name comparisons. + + Returns: + This for method chaining + """ + self._case_sensitive = case_sensitive + return self + + def add_column( + self, name: str, type_var: IcebergType, doc: Optional[str] = None, parent: Optional[str] = None, required: bool = False + ) -> UpdateSchema: + """Add a new column to a nested struct or Add a new top-level column. + + Args: + name: Name for the new column. + type_var: Type for the new column. + doc: Documentation string for the new column. + parent: Name of the parent struct to the column will be added to. + required: Whether the new column is required. + + Returns: + This for method chaining + """ + if "." in name: + raise ValueError(f"Cannot add column with ambiguous name: {name}") + + if required and not self._allow_incompatible_changes: + # Table format version 1 and 2 cannot add required column because there is no initial value + raise ValueError(f"Incompatible change: cannot add required column: {name}") + + self._internal_add_column(parent, name, not required, type_var, doc) + return self + + def allow_incompatible_changes(self) -> UpdateSchema: + """Allow incompatible changes to the schema. + + Returns: + This for method chaining + """ + self._allow_incompatible_changes = True + return self + + def commit(self) -> None: + """Apply the pending changes and commit.""" + new_schema = self._apply() + updates = [ + AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id), + SetCurrentSchemaUpdate(schema_id=-1), + ] + requirements = [AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)] + + if self._transaction is not None: + self._transaction._append_updates(*updates) # pylint: disable=W0212 + self._transaction._append_requirements(*requirements) # pylint: disable=W0212 + else: + table_update_response = self._table.catalog._commit_table( # pylint: disable=W0212 + CommitTableRequest(identifier=self._table.identifier[1:], updates=updates, requirements=requirements) + ) + self._table.metadata = table_update_response.metadata + self._table.metadata_location = table_update_response.metadata_location + + def _apply(self) -> Schema: + """Apply the pending changes to the original schema and returns the result. + + Returns: + the result Schema when all pending updates are applied + """ + return _apply_changes(self._schema, self._adds, self._identifier_field_names) + + def _internal_add_column( + self, parent: Optional[str], name: str, is_optional: bool, type_var: IcebergType, doc: Optional[str] + ) -> None: + full_name: str = name + parent_id: int = TABLE_ROOT_ID + + exist_field: Optional[NestedField] = None + if parent: + parent_field = self._schema.find_field(parent, self._case_sensitive) + parent_type = parent_field.field_type + if isinstance(parent_type, MapType): + parent_field = parent_type.value_field + elif isinstance(parent_type, ListType): + parent_field = parent_type.element_field + + if not parent_field.field_type.is_struct: + raise ValueError(f"Cannot add column to non-struct type: {parent}") + + parent_id = parent_field.field_id + + try: + exist_field = self._schema.find_field(parent + "." + name, self._case_sensitive) + except ValueError: + pass + + if exist_field: + raise ValueError(f"Cannot add column, name already exists: {parent}.{name}") + + full_name = parent_field.name + "." + name + + else: + try: + exist_field = self._schema.find_field(name, self._case_sensitive) + except ValueError: + pass + + if exist_field: + raise ValueError(f"Cannot add column, name already exists: {name}") + + # assign new IDs in order + new_id = self.assign_new_column_id() + + # update tracking for moves + self._added_name_to_id[full_name] = new_id + + new_type = assign_fresh_schema_ids(type_var, self.assign_new_column_id) + field = NestedField(new_id, name, new_type, not is_optional, doc) + + self._adds.setdefault(parent_id, []).append(field) + + def assign_new_column_id(self) -> int: + return next(self._last_column_id) + + +def _apply_changes(schema_: Schema, adds: Dict[int, List[NestedField]], identifier_field_names: List[str]) -> Schema: + struct = visit(schema_, _ApplyChanges(adds)) + name_to_id: Dict[str, int] = index_by_name(struct) + for name in identifier_field_names: + if name not in name_to_id: + raise ValueError(f"Cannot add field {name} as an identifier field: not found in current schema or added columns") + + return Schema(*struct.fields) + + +class _ApplyChanges(SchemaVisitor[IcebergType]): + def __init__(self, adds: Dict[int, List[NestedField]]): + self.adds = adds + + def schema(self, schema: Schema, struct_result: IcebergType) -> IcebergType: + fields = _ApplyChanges.add_fields(schema.as_struct().fields, self.adds.get(TABLE_ROOT_ID)) + if len(fields) > 0: + return StructType(*fields) + + return struct_result + + def struct(self, struct: StructType, field_results: List[IcebergType]) -> IcebergType: + has_change = False + new_fields: List[NestedField] = [] + for i in range(len(field_results)): + type_: Optional[IcebergType] = field_results[i] + if type_ is None: + has_change = True + continue + + field: NestedField = struct.fields[i] + new_fields.append(field) + + if has_change: + return StructType(*new_fields) + + return struct + + def field(self, field: NestedField, field_result: IcebergType) -> IcebergType: + field_id: int = field.field_id + if field_id in self.adds: + new_fields = self.adds[field_id] + if len(new_fields) > 0: + fields = _ApplyChanges.add_fields(field_result.fields, new_fields) + if len(fields) > 0: + return StructType(*fields) + + return field_result + + def list(self, list_type: ListType, element_result: IcebergType) -> IcebergType: + element_field: NestedField = list_type.element_field + element_type = self.field(element_field, element_result) + if element_type is None: + raise ValueError(f"Cannot delete element type from list: {element_field}") + + is_element_optional = not list_type.element_required + + if is_element_optional == element_field.required and list_type.element_type == element_type: + return list_type + + return ListType(list_type.element_id, element_type, is_element_optional) + + def map(self, map_type: MapType, key_result: IcebergType, value_result: IcebergType) -> IcebergType: + key_id: int = map_type.key_field.field_id + if key_id in self.adds: + raise ValueError(f"Cannot add fields to map keys: {map_type}") + + value_field: NestedField = map_type.value_field + value_type = self.field(value_field, value_result) + if value_type is None: + raise ValueError(f"Cannot delete value type from map: {value_field}") + + is_value_optional = not map_type.value_required + + if is_value_optional != value_field.required and map_type.value_type == value_type: + return map_type + + return MapType(map_type.key_id, map_type.key_field, map_type.value_id, value_type, not is_value_optional) + + def primitive(self, primitive: PrimitiveType) -> IcebergType: + return primitive + + @staticmethod + def add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> List[NestedField]: + new_fields: List[NestedField] = [] + new_fields.extend(fields) + if adds: + new_fields.extend(adds) + return new_fields diff --git a/python/pyiceberg/table/metadata.py b/python/pyiceberg/table/metadata.py index e6a3e6f16e58..690f5d4d59e2 100644 --- a/python/pyiceberg/table/metadata.py +++ b/python/pyiceberg/table/metadata.py @@ -388,12 +388,20 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: def new_table_metadata( - schema: Schema, partition_spec: PartitionSpec, sort_order: SortOrder, location: str, properties: Properties = EMPTY_DICT + schema: Schema, + partition_spec: PartitionSpec, + sort_order: SortOrder, + location: str, + properties: Properties = EMPTY_DICT, + table_uuid: Optional[uuid.UUID] = None, ) -> TableMetadata: fresh_schema = assign_fresh_schema_ids(schema) fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, schema, fresh_schema) fresh_sort_order = assign_fresh_sort_order_ids(sort_order, schema, fresh_schema) + if table_uuid is None: + table_uuid = uuid.uuid4() + return TableMetadataV2( location=location, schemas=[fresh_schema], @@ -405,6 +413,7 @@ def new_table_metadata( default_sort_order_id=fresh_sort_order.order_id, properties=properties, last_partition_id=fresh_partition_spec.last_assigned_field_id, + table_uuid=table_uuid, ) diff --git a/python/tests/catalog/test_base.py b/python/tests/catalog/test_base.py index b47aa5f5f799..29e93d0c9d05 100644 --- a/python/tests/catalog/test_base.py +++ b/python/tests/catalog/test_base.py @@ -42,11 +42,18 @@ from pyiceberg.io import load_file_io from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table -from pyiceberg.table.metadata import TableMetadataV1 +from pyiceberg.table import ( + AddSchemaUpdate, + CommitTableRequest, + CommitTableResponse, + SetCurrentSchemaUpdate, + Table, +) +from pyiceberg.table.metadata import TableMetadata, TableMetadataV1, new_table_metadata from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.transforms import IdentityTransform from pyiceberg.typedef import EMPTY_DICT +from pyiceberg.types import IntegerType, LongType, NestedField class InMemoryCatalog(Catalog): @@ -78,29 +85,24 @@ def create_table( if namespace not in self.__namespaces: self.__namespaces[namespace] = {} + new_location = location or f's3://warehouse/{"/".join(identifier)}/data' + metadata = TableMetadataV1( + **{ + "format-version": 1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": new_location, + "last-updated-ms": 1602638573874, + "last-column-id": schema.highest_field_id, + "schema": schema.model_dump(), + "partition-spec": partition_spec.model_dump()["fields"], + "properties": properties, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], + } + ) table = Table( identifier=identifier, - metadata=TableMetadataV1( - **{ - "format-version": 1, - "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": True, "type": "long"}, - {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": True, "type": "long"}, - ], - }, - "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": properties, - "current-snapshot-id": -1, - "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], - } - ), + metadata=metadata, metadata_location=f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json', io=load_file_io(), catalog=self, @@ -109,7 +111,37 @@ def create_table( return table def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse: - raise NotImplementedError + new_metadata: Optional[TableMetadata] = None + metadata_location = "" + for update in table_request.updates: + if isinstance(update, AddSchemaUpdate): + add_schema_update: AddSchemaUpdate = update + identifier = Catalog.identifier_to_tuple(table_request.identifier) + table = self.__tables[("com", *identifier)] + new_metadata = new_table_metadata( + add_schema_update.schema_, + table.metadata.partition_specs[0], + table.sort_order(), + table.location(), + table.properties, + table.metadata.table_uuid, + ) + + table = Table( + identifier=identifier, + metadata=new_metadata, + metadata_location=f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json', + io=load_file_io(), + catalog=self, + ) + + self.__tables[identifier] = table + metadata_location = f's3://warehouse/{"/".join(identifier)}/metadata/metadata.json' + + return CommitTableResponse( + metadata=new_metadata.model_dump() if new_metadata else {}, + metadata_location=metadata_location if metadata_location else "", + ) def load_table(self, identifier: Union[str, Identifier]) -> Table: identifier = Catalog.identifier_to_tuple(identifier) @@ -223,7 +255,11 @@ def catalog() -> InMemoryCatalog: TEST_TABLE_IDENTIFIER = ("com", "organization", "department", "my_table") TEST_TABLE_NAMESPACE = ("com", "organization", "department") TEST_TABLE_NAME = "my_table" -TEST_TABLE_SCHEMA = Schema(schema_id=1) +TEST_TABLE_SCHEMA = Schema( + NestedField(1, "x", LongType()), + NestedField(2, "y", LongType(), doc="comment"), + NestedField(3, "z", LongType()), +) TEST_TABLE_LOCATION = "protocol://some/location" TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x", transform=IdentityTransform(), source_id=1, field_id=1000)) TEST_TABLE_PROPERTIES = {"key1": "value1", "key2": "value2"} @@ -239,7 +275,7 @@ def given_catalog_has_a_table(catalog: InMemoryCatalog) -> Table: identifier=TEST_TABLE_IDENTIFIER, schema=TEST_TABLE_SCHEMA, location=TEST_TABLE_LOCATION, - partition_spec=UNPARTITIONED_PARTITION_SPEC, + partition_spec=TEST_TABLE_PARTITION_SPEC, properties=TEST_TABLE_PROPERTIES, ) @@ -474,3 +510,88 @@ def test_update_namespace_metadata_removals(catalog: InMemoryCatalog) -> None: def test_update_namespace_metadata_raises_error_when_namespace_does_not_exist(catalog: InMemoryCatalog) -> None: with pytest.raises(NoSuchNamespaceError, match=NO_SUCH_NAMESPACE_ERROR): catalog.update_namespace_properties(TEST_TABLE_NAMESPACE, updates=TEST_TABLE_PROPERTIES) + + +def test_commit_table(catalog: InMemoryCatalog) -> None: + # Given + given_table = given_catalog_has_a_table(catalog) + new_schema = Schema( + NestedField(1, "x", LongType()), + NestedField(2, "y", LongType(), doc="comment"), + NestedField(3, "z", LongType()), + NestedField(4, "add", LongType()), + ) + + # When + response = given_table.catalog._commit_table( # pylint: disable=W0212 + CommitTableRequest( + identifier=given_table.identifier[1:], + updates=[ + AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id), + SetCurrentSchemaUpdate(schema_id=-1), + ], + ) + ) + + # Then + assert response.metadata.table_uuid == given_table.metadata.table_uuid + assert len(response.metadata.schemas) == 1 + assert response.metadata.schemas[0] == new_schema + + +def test_add_column(catalog: InMemoryCatalog) -> None: + given_table = given_catalog_has_a_table(catalog) + + given_table.update_schema().add_column(name="new_column1", type_var=IntegerType()).commit() + + assert given_table.schema() == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="new_column1", field_type=IntegerType(), required=False), + schema_id=0, + identifier_field_ids=[], + ) + + transaction = given_table.transaction() + transaction.update_schema().add_column(name="new_column2", type_var=IntegerType(), doc="doc").commit() + transaction.commit_transaction() + + assert given_table.schema() == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="new_column1", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="new_column2", field_type=IntegerType(), required=False, doc="doc"), + schema_id=0, + identifier_field_ids=[], + ) + + +def test_add_column_with_statement(catalog: InMemoryCatalog) -> None: + given_table = given_catalog_has_a_table(catalog) + + with given_table.update_schema() as tx: + tx.add_column(name="new_column1", type_var=IntegerType()) + + assert given_table.schema() == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="new_column1", field_type=IntegerType(), required=False), + schema_id=0, + identifier_field_ids=[], + ) + + with given_table.transaction() as tx: + tx.update_schema().add_column(name="new_column2", type_var=IntegerType(), doc="doc").commit() + + assert given_table.schema() == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="new_column1", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="new_column2", field_type=IntegerType(), required=False, doc="doc"), + schema_id=0, + identifier_field_ids=[], + ) diff --git a/python/tests/cli/test_console.py b/python/tests/cli/test_console.py index 12c82c2cde6e..45eb4dd1be8e 100644 --- a/python/tests/cli/test_console.py +++ b/python/tests/cli/test_console.py @@ -25,6 +25,7 @@ from pyiceberg.schema import Schema from pyiceberg.transforms import IdentityTransform from pyiceberg.typedef import Properties +from pyiceberg.types import LongType, NestedField from pyiceberg.utils.config import Config from tests.catalog.test_base import InMemoryCatalog @@ -62,8 +63,12 @@ def fixture_namespace_properties() -> Properties: TEST_TABLE_NAMESPACE = "default" TEST_NAMESPACE_PROPERTIES = {"location": "s3://warehouse/database/location"} TEST_TABLE_NAME = "my_table" -TEST_TABLE_SCHEMA = Schema(schema_id=0) -TEST_TABLE_LOCATION = "protocol://some/location" +TEST_TABLE_SCHEMA = Schema( + NestedField(1, "x", LongType()), + NestedField(2, "y", LongType(), doc="comment"), + NestedField(3, "z", LongType()), +) +TEST_TABLE_LOCATION = "s3://bucket/test/location" TEST_TABLE_PARTITION_SPEC = PartitionSpec(PartitionField(name="x", transform=IdentityTransform(), source_id=1, field_id=1000)) TEST_TABLE_PROPERTIES = {"read.split.target.size": "134217728"} MOCK_ENVIRONMENT = {"PYICEBERG_CATALOG__PRODUCTION__URI": "test://doesnotexist"} @@ -558,7 +563,7 @@ def test_json_describe_table(catalog: InMemoryCatalog) -> None: assert result.exit_code == 0 assert ( result.output - == """{"identifier":["default","my_table"],"metadata_location":"s3://warehouse/default/my_table/metadata/metadata.json","metadata":{"location":"s3://bucket/test/location","table-uuid":"d20125c8-7284-442c-9aea-15fee620737c","last-updated-ms":1602638573874,"last-column-id":3,"schemas":[{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"schema-id":0,"identifier-field-ids":[]}],"current-schema-id":0,"partition-specs":[{"spec-id":0,"fields":[{"source-id":1,"field-id":1000,"transform":"identity","name":"x"}]}],"default-spec-id":0,"last-partition-id":1000,"properties":{},"snapshots":[{"snapshot-id":1925,"timestamp-ms":1602638573822}],"snapshot-log":[],"metadata-log":[],"sort-orders":[{"order-id":0,"fields":[]}],"default-sort-order-id":0,"refs":{},"format-version":1,"schema":{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"schema-id":0,"identifier-field-ids":[]},"partition-spec":[{"name":"x","transform":"identity","source-id":1,"field-id":1000}]}}\n""" + == """{"identifier":["default","my_table"],"metadata_location":"s3://warehouse/default/my_table/metadata/metadata.json","metadata":{"location":"s3://bucket/test/location","table-uuid":"d20125c8-7284-442c-9aea-15fee620737c","last-updated-ms":1602638573874,"last-column-id":3,"schemas":[{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"schema-id":0,"identifier-field-ids":[]}],"current-schema-id":0,"partition-specs":[{"spec-id":0,"fields":[{"source-id":1,"field-id":1000,"transform":"identity","name":"x"}]}],"default-spec-id":0,"last-partition-id":1000,"properties":{},"snapshots":[{"snapshot-id":1925,"timestamp-ms":1602638573822}],"snapshot-log":[],"metadata-log":[],"sort-orders":[{"order-id":0,"fields":[]}],"default-sort-order-id":0,"refs":{},"format-version":1,"schema":{"type":"struct","fields":[{"id":1,"name":"x","type":"long","required":true},{"id":2,"name":"y","type":"long","required":true,"doc":"comment"},{"id":3,"name":"z","type":"long","required":true}],"schema-id":0,"identifier-field-ids":[]},"partition-spec":[{"source-id":1,"field-id":1000,"transform":"identity","name":"x"}]}}\n""" ) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 9a560284ea8a..67fc9927809e 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -57,6 +57,7 @@ from pyiceberg import schema from pyiceberg.catalog import Catalog +from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.io import ( GCS_ENDPOINT, GCS_PROJECT_ID, @@ -65,13 +66,14 @@ OutputFile, OutputStream, fsspec, + load_file_io, ) from pyiceberg.io.fsspec import FsspecFileIO from pyiceberg.io.pyarrow import PyArrowFile, PyArrowFileIO from pyiceberg.manifest import DataFile, FileFormat from pyiceberg.schema import Schema from pyiceberg.serializers import ToOutputFile -from pyiceberg.table import FileScanTask +from pyiceberg.table import FileScanTask, Table from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.types import ( BinaryType, @@ -194,6 +196,63 @@ def table_schema_nested() -> Schema: ) +@pytest.fixture(scope="session") +def table_schema_nested_with_struct_key_map() -> Schema: + return schema.Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=6, + name="quux", + field_type=MapType( + key_id=7, + key_type=StringType(), + value_id=8, + value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=11, + name="location", + field_type=MapType( + key_id=18, + value_id=19, + key_type=StructType( + NestedField(field_id=21, name="address", field_type=StringType(), required=False), + NestedField(field_id=22, name="city", field_type=StringType(), required=False), + NestedField(field_id=23, name="zip", field_type=IntegerType(), required=False), + ), + value_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=15, + name="person", + field_type=StructType( + NestedField(field_id=16, name="name", field_type=StringType(), required=False), + NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + schema_id=1, + identifier_field_ids=[1], + ) + + @pytest.fixture(scope="session") def all_avro_types() -> Dict[str, Any]: return { @@ -1561,3 +1620,15 @@ def example_task(data_file: str) -> FileScanTask: return FileScanTask( data_file=DataFile(file_path=data_file, file_format=FileFormat.PARQUET, file_size_in_bytes=1925), ) + + +@pytest.fixture +def table(example_table_metadata_v2: Dict[str, Any]) -> Table: + table_metadata = TableMetadataV2(**example_table_metadata_v2) + return Table( + identifier=("database", "table"), + metadata=table_metadata, + metadata_location=f"{table_metadata.location}/uuid.metadata.json", + io=load_file_io(), + catalog=NoopCatalog("NoopCatalog"), + ) diff --git a/python/tests/table/test_init.py b/python/tests/table/test_init.py index 2587fb76d923..b25e445032fd 100644 --- a/python/tests/table/test_init.py +++ b/python/tests/table/test_init.py @@ -15,19 +15,18 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name -from typing import Any, Dict +from typing import Dict import pytest from sortedcontainers import SortedList -from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.expressions import ( AlwaysTrue, And, EqualTo, In, ) -from pyiceberg.io import PY_IO_IMPL, load_file_io +from pyiceberg.io import PY_IO_IMPL from pyiceberg.manifest import ( DataFile, DataFileContent, @@ -41,9 +40,10 @@ SetPropertiesUpdate, StaticTable, Table, + UpdateSchema, _match_deletes_to_datafile, ) -from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataV2 +from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( Operation, Snapshot, @@ -57,19 +57,25 @@ SortOrder, ) from pyiceberg.transforms import BucketTransform, IdentityTransform -from pyiceberg.types import LongType, NestedField - - -@pytest.fixture -def table(example_table_metadata_v2: Dict[str, Any]) -> Table: - table_metadata = TableMetadataV2(**example_table_metadata_v2) - return Table( - identifier=("database", "table"), - metadata=table_metadata, - metadata_location=f"{table_metadata.location}/uuid.metadata.json", - io=load_file_io(), - catalog=NoopCatalog("NoopCatalog"), - ) +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DoubleType, + FloatType, + IntegerType, + ListType, + LongType, + MapType, + NestedField, + PrimitiveType, + StringType, + StructType, + TimestampType, + TimestamptzType, + TimeType, + UUIDType, +) def test_schema(table: Table) -> None: @@ -388,3 +394,176 @@ def test_match_deletes_to_datafile_duplicate_number() -> None: def test_serialize_set_properties_updates() -> None: assert SetPropertiesUpdate(updates={"abc": "🤪"}).model_dump_json() == """{"action":"set-properties","updates":{"abc":"🤪"}}""" + + +def test_add_column(table_schema_simple: Schema, table: Table) -> None: + update = UpdateSchema(table_schema_simple, table) + update.add_column(name="b", type_var=IntegerType()) + apply_schema: Schema = update._apply() # pylint: disable=W0212 + assert len(apply_schema.fields) == 4 + + assert apply_schema == Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + NestedField(field_id=4, name="b", field_type=IntegerType(), required=False), + ) + assert apply_schema.schema_id == 0 + assert apply_schema.highest_field_id == 4 + + +def test_add_primitive_type_column(table_schema_simple: Schema, table: Table) -> None: + primitive_type: Dict[str, PrimitiveType] = { + "boolean": BooleanType(), + "int": IntegerType(), + "long": LongType(), + "float": FloatType(), + "double": DoubleType(), + "date": DateType(), + "time": TimeType(), + "timestamp": TimestampType(), + "timestamptz": TimestamptzType(), + "string": StringType(), + "uuid": UUIDType(), + "binary": BinaryType(), + } + + for name, type_ in primitive_type.items(): + field_name = f"new_column_{name}" + update = UpdateSchema(table_schema_simple, table) + update.add_column(parent=None, name=field_name, type_var=type_, doc=f"new_column_{name}") + new_schema = update._apply() # pylint: disable=W0212 + + field: NestedField = new_schema.find_field(field_name) + assert field.field_type == type_ + assert field.doc == f"new_column_{name}" + + +def test_add_nested_type_column(table_schema_simple: Schema, table: Table) -> None: + # add struct type column + field_name = "new_column_struct" + update = UpdateSchema(table_schema_simple, table) + struct_ = StructType( + NestedField(1, "lat", DoubleType()), + NestedField(2, "long", DoubleType()), + ) + update.add_column(parent=None, name=field_name, type_var=struct_) + schema_ = update._apply() # pylint: disable=W0212 + field: NestedField = schema_.find_field(field_name) + assert field.field_type == StructType( + NestedField(5, "lat", DoubleType()), + NestedField(6, "long", DoubleType()), + ) + assert schema_.highest_field_id == 6 + + +def test_add_nested_map_type_column(table_schema_simple: Schema, table: Table) -> None: + # add map type column + field_name = "new_column_map" + update = UpdateSchema(table_schema_simple, table) + map_ = MapType(1, StringType(), 2, IntegerType(), False) + update.add_column(parent=None, name=field_name, type_var=map_) + new_schema = update._apply() # pylint: disable=W0212 + field: NestedField = new_schema.find_field(field_name) + assert field.field_type == MapType(5, StringType(), 6, IntegerType(), False) + assert new_schema.highest_field_id == 6 + + +def test_add_nested_list_type_column(table_schema_simple: Schema, table: Table) -> None: + # add list type column + field_name = "new_column_list" + update = UpdateSchema(table_schema_simple, table) + list_ = ListType( + element_id=101, + element_type=StructType( + NestedField(102, "lat", DoubleType()), + NestedField(103, "long", DoubleType()), + ), + element_required=False, + ) + update.add_column(parent=None, name=field_name, type_var=list_) + new_schema = update._apply() # pylint: disable=W0212 + field: NestedField = new_schema.find_field(field_name) + assert field.field_type == ListType( + element_id=5, + element_type=StructType( + NestedField(6, "lat", DoubleType()), + NestedField(7, "long", DoubleType()), + ), + element_required=False, + ) + assert new_schema.highest_field_id == 7 + + +def test_add_field_to_map_key(table_schema_nested_with_struct_key_map: Schema, table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(table_schema_nested_with_struct_key_map, table) + update.add_column(name="b", type_var=IntegerType(), parent="location.key")._apply() # pylint: disable=W0212 + assert "Cannot add fields to map keys" in str(exc_info.value) + + +def test_add_already_exists(table_schema_nested: Schema, table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(table_schema_nested, table) + update.add_column("foo", IntegerType()) + assert "already exists: foo" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(table_schema_nested, table) + update.add_column(name="latitude", type_var=IntegerType(), parent="location") + assert "already exists: location.lat" in str(exc_info.value) + + +def test_add_to_non_struct_type(table_schema_simple: Schema, table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(table_schema_simple, table) + update.add_column(name="lat", type_var=IntegerType(), parent="foo") + assert "Cannot add column to non-struct type" in str(exc_info.value) + + +def test_add_required_column(table: Table) -> None: + schema_ = Schema( + NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[] + ) + + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(schema_, table) + update.add_column(name="data", type_var=IntegerType(), required=True) + assert "Incompatible change: cannot add required column: data" in str(exc_info.value) + + new_schema = ( + UpdateSchema(schema_, table) # pylint: disable=W0212 + .allow_incompatible_changes() + .add_column(name="data", type_var=IntegerType(), required=True) + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="data", field_type=IntegerType(), required=True), + schema_id=0, + identifier_field_ids=[], + ) + + +def test_add_required_column_case_insensitive(table: Table) -> None: + schema_ = Schema( + NestedField(field_id=1, name="id", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[] + ) + + with pytest.raises(ValueError) as exc_info: + update = UpdateSchema(schema_, table) + update.allow_incompatible_changes().case_sensitive(False).add_column(name="ID", type_var=IntegerType(), required=True) + assert "already exists: ID" in str(exc_info.value) + + new_schema = ( + UpdateSchema(schema_, table) # pylint: disable=W0212 + .allow_incompatible_changes() + .add_column(name="ID", type_var=IntegerType(), required=True) + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="id", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="ID", field_type=IntegerType(), required=True), + schema_id=0, + identifier_field_ids=[], + ) diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index a63436bdaead..acd694677463 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -25,7 +25,7 @@ from pyarrow.fs import S3FileSystem from pyiceberg.catalog import Catalog, load_catalog -from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.exceptions import CommitFailedException, NoSuchTableError from pyiceberg.expressions import ( And, EqualTo, @@ -40,10 +40,14 @@ from pyiceberg.table import Table from pyiceberg.types import ( BooleanType, + DoubleType, + FixedType, IntegerType, + LongType, NestedField, StringType, TimestampType, + UUIDType, ) @@ -352,3 +356,89 @@ def test_unpartitioned_fixed_table(catalog: Catalog) -> None: b"12345678901234567ass12345", b"qweeqwwqq1231231231231111", ] + + +@pytest.mark.integration +def test_schema_evolution(catalog: Catalog) -> None: + try: + catalog.drop_table("default.test_schema_evolution") + except NoSuchTableError: + pass + + schema = Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + ) + + t = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) + + assert t.schema() == schema + + with t.update_schema() as tx: + tx.add_column("col_string", StringType()) + + assert t.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + schema_id=1, + ) + + +@pytest.mark.integration +def test_schema_evolution_via_transaction(catalog: Catalog) -> None: + try: + catalog.drop_table("default.test_schema_evolution") + except NoSuchTableError: + pass + + schema = Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + ) + + tbl = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) + + assert tbl.schema() == schema + + with tbl.transaction() as tx: + tx.update_schema().add_column("col_string", StringType()).commit() + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + schema_id=1, + ) + + tbl.update_schema().add_column("col_integer", IntegerType()).commit() + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + schema_id=1, + ) + + with pytest.raises(CommitFailedException) as exc_info: + with tbl.transaction() as tx: + # Start a new update + schema_update = tx.update_schema() + + # Do a concurrent update + tbl.update_schema().add_column("col_long", LongType()).commit() + + # stage another update in the transaction + schema_update.add_column("col_double", DoubleType()).commit() + + assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="col_long", field_type=LongType(), required=False), + schema_id=1, + ) diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 57f194734694..50d788b953cd 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -334,78 +334,6 @@ def test_schema_find_field_type_by_id(table_schema_simple: Schema) -> None: assert index[3] == NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False) -def test_index_by_id_schema_visitor(table_schema_nested: Schema) -> None: - """Test the index_by_id function that uses the IndexById schema visitor""" - assert schema.index_by_id(table_schema_nested) == { - 1: NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - 2: NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - 3: NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - 4: NestedField( - field_id=4, - name="qux", - field_type=ListType(element_id=5, element_type=StringType(), element_required=True), - required=True, - ), - 5: NestedField(field_id=5, name="element", field_type=StringType(), required=True), - 6: NestedField( - field_id=6, - name="quux", - field_type=MapType( - key_id=7, - key_type=StringType(), - value_id=8, - value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), - value_required=True, - ), - required=True, - ), - 7: NestedField(field_id=7, name="key", field_type=StringType(), required=True), - 8: NestedField( - field_id=8, - name="value", - field_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), - required=True, - ), - 9: NestedField(field_id=9, name="key", field_type=StringType(), required=True), - 10: NestedField(field_id=10, name="value", field_type=IntegerType(), required=True), - 11: NestedField( - field_id=11, - name="location", - field_type=ListType( - element_id=12, - element_type=StructType( - NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False), - NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False), - ), - element_required=True, - ), - required=True, - ), - 12: NestedField( - field_id=12, - name="element", - field_type=StructType( - NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False), - NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False), - ), - required=True, - ), - 13: NestedField(field_id=13, name="latitude", field_type=FloatType(), required=False), - 14: NestedField(field_id=14, name="longitude", field_type=FloatType(), required=False), - 15: NestedField( - field_id=15, - name="person", - field_type=StructType( - NestedField(field_id=16, name="name", field_type=StringType(), required=False), - NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), - ), - required=False, - ), - 16: NestedField(field_id=16, name="name", field_type=StringType(), required=False), - 17: NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), - } - - def test_index_by_id_schema_visitor_raise_on_unregistered_type() -> None: """Test raising a NotImplementedError when an invalid type is provided to the index_by_id function""" with pytest.raises(NotImplementedError) as exc_info: