diff --git a/python/Makefile b/python/Makefile index 3597a34cf1f8..80e6f4dee78e 100644 --- a/python/Makefile +++ b/python/Makefile @@ -30,7 +30,7 @@ lint: poetry run pre-commit run --all-files test: - poetry run pytest tests/ -m "unmarked or parametrize" ${PYTEST_ARGS} + poetry run pytest tests/ -m "(unmarked or parametrize) and not integration" ${PYTEST_ARGS} test-s3: sh ./dev/run-minio.sh @@ -53,15 +53,18 @@ test-adlfs: sh ./dev/run-azurite.sh poetry run pytest tests/ -m adlfs ${PYTEST_ARGS} +test-gcs: + sh ./dev/run-gcs-server.sh + poetry run pytest tests/ -m gcs ${PYTEST_ARGS} + test-coverage: - sh ./dev/run-minio.sh + docker-compose -f dev/docker-compose-integration.yml kill + docker-compose -f dev/docker-compose-integration.yml rm -f + docker-compose -f dev/docker-compose-integration.yml up -d sh ./dev/run-azurite.sh sh ./dev/run-gcs-server.sh - poetry run coverage run --source=pyiceberg/ -m pytest tests/ -m "not integration" ${PYTEST_ARGS} + docker-compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py + poetry run coverage run --source=pyiceberg/ -m pytest tests/ ${PYTEST_ARGS} poetry run coverage report -m --fail-under=90 poetry run coverage html poetry run coverage xml - -test-gcs: - sh ./dev/run-gcs-server.sh - poetry run pytest tests/ -m gcs ${PYTEST_ARGS} diff --git a/python/mkdocs/docs/api.md b/python/mkdocs/docs/api.md index f0b2873c038e..55eadc5f5b45 100644 --- a/python/mkdocs/docs/api.md +++ b/python/mkdocs/docs/api.md @@ -42,65 +42,36 @@ Then load the `prod` catalog: ```python from pyiceberg.catalog import load_catalog -catalog = load_catalog("prod") - -catalog.list_namespaces() +catalog = load_catalog( + "docs", + **{ + "uri": "http://127.0.0.1:8181", + "s3.endpoint": "http://127.0.0.1:9000", + "py-io-impl": "pyiceberg.io.pyarrow.PyArrowFileIO", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + } +) ``` -Returns two namespaces: +Let's create a namespace: ```python -[("default",), ("nyc",)] +catalog.create_namespace("docs_example") ``` -Listing the tables in the `nyc` namespace: +And then list them: ```python -catalog.list_tables("nyc") -``` - -Returns as list with tuples, containing a single table `taxis`: +ns = catalog.list_namespaces() -```python -[("nyc", "taxis")] +assert ns == [("docs_example",)] ``` -## Load a table - -### From a catalog - -Loading the `taxis` table: +And then list tables in the namespace: ```python -catalog.load_table("nyc.taxis") -# Equivalent to: -catalog.load_table(("nyc", "taxis")) -# The tuple syntax can be used if the namespace or table contains a dot. -``` - -This returns a `Table` that represents an Iceberg table that can be queried and altered. - -### Directly from a metadata file - -To load a table directly from a metadata file (i.e., **without** using a catalog), you can use a `StaticTable` as follows: - -```python -from pyiceberg.table import StaticTable - -table = StaticTable.from_metadata( - "s3a://warehouse/wh/nyc.db/taxis/metadata/00002-6ea51ce3-62aa-4197-9cf8-43d07c3440ca.metadata.json" -) -``` - -For the rest, this table behaves similarly as a table loaded using a catalog. Note that `StaticTable` is intended to be _read only_. 
- -Any properties related to file IO can be passed accordingly: - -```python -table = StaticTable.from_metadata( - "s3a://warehouse/wh/nyc.db/taxis/metadata/00002-6ea51ce3-62aa-4197-9cf8-43d07c3440ca.metadata.json", - {PY_IO_IMPL: "pyiceberg.some.FileIO.class"}, -) +catalog.list_tables("docs_example") ``` ## Create a table @@ -108,17 +79,31 @@ table = StaticTable.from_metadata( To create a table from a catalog: ```python -from pyiceberg.catalog import load_catalog from pyiceberg.schema import Schema -from pyiceberg.types import TimestampType, DoubleType, StringType, NestedField +from pyiceberg.types import ( + TimestampType, + FloatType, + DoubleType, + StringType, + NestedField, + StructType, +) schema = Schema( + NestedField(field_id=1, name="datetime", field_type=TimestampType(), required=True), + NestedField(field_id=2, name="symbol", field_type=StringType(), required=True), + NestedField(field_id=3, name="bid", field_type=FloatType(), required=False), + NestedField(field_id=4, name="ask", field_type=DoubleType(), required=False), NestedField( - field_id=1, name="datetime", field_type=TimestampType(), required=False + field_id=5, + name="details", + field_type=StructType( + NestedField( + field_id=4, name="created_by", field_type=StringType(), required=False + ), + ), + required=False, ), - NestedField(field_id=2, name="bid", field_type=DoubleType(), required=False), - NestedField(field_id=3, name="ask", field_type=DoubleType(), required=False), - NestedField(field_id=4, name="symbol", field_type=StringType(), required=False), ) from pyiceberg.partitioning import PartitionSpec, PartitionField @@ -133,52 +118,132 @@ partition_spec = PartitionSpec( from pyiceberg.table.sorting import SortOrder, SortField from pyiceberg.transforms import IdentityTransform -sort_order = SortOrder(SortField(source_id=4, transform=IdentityTransform())) - -catalog = load_catalog("prod") +# Sort on the symbol +sort_order = SortOrder(SortField(source_id=2, transform=IdentityTransform())) catalog.create_table( - identifier="default.bids", - location="/Users/fokkodriesprong/Desktop/docker-spark-iceberg/wh/bids/", + identifier="docs_example.bids", schema=schema, partition_spec=partition_spec, sort_order=sort_order, ) ``` -### Update table schema +## Load a table -Add new columns through the `Transaction` or `UpdateSchema` API: +### Catalog table -Use the Transaction API: +Loading the `bids` table: + +```python +table = catalog.load_table("docs_example.bids") +# Equivalent to: +table = catalog.load_table(("docs_example", "bids")) +# The tuple syntax can be used if the namespace or table contains a dot. +``` + +This returns a `Table` that represents an Iceberg table that can be queried and altered. + +### Static table + +To load a table directly from a metadata file (i.e., **without** using a catalog), you can use a `StaticTable` as follows: + +```python +from pyiceberg.table import StaticTable + +static_table = StaticTable.from_metadata( + "s3://warehouse/wh/nyc.db/taxis/metadata/00002-6ea51ce3-62aa-4197-9cf8-43d07c3440ca.metadata.json" +) +``` + +The static-table is considered read-only. + +## Schema evolution + +PyIceberg supports full schema evolution through the Python API. It takes care of setting the field-IDs and makes sure that only non-breaking changes are done (can be overriden). + +In the examples below, the `.update_schema()` is called from the table itself. 
+
+```python
+with table.update_schema() as update:
+    update.add_column("some_field", IntegerType(), "doc")
+```
+
+You can also initiate a transaction if you want to make more changes than just evolving the schema:
 
 ```python
 with table.transaction() as transaction:
-    transaction.update_schema().add_column("x", IntegerType(), "doc").commit()
+    with transaction.update_schema() as update_schema:
+        update_schema.add_column("some_other_field", IntegerType(), "doc")
+    # ... Update properties etc
+```
+
+### Add column
+
+Using `add_column` you can add a column without having to worry about the field-id:
+
+```python
+with table.update_schema() as update:
+    update.add_column("retries", IntegerType(), "Number of retries to place the bid")
+    # In a struct
+    update.add_column("details.confirmed_by", StringType(), "Name of the exchange")
+```
+
+### Rename column
+
+Renaming a field in an Iceberg table is simple:
+
+```python
+with table.update_schema() as update:
+    update.rename_column("retries", "num_retries")
+    # This will rename `confirmed_by` to `exchange`
+    update.rename_column("details.confirmed_by", "exchange")
 ```
 
-Or, without a context manager:
+### Move column
+
+Move a field inside of a struct:
 
 ```python
-transaction = table.transaction()
-transaction.update_schema().add_column("x", IntegerType(), "doc").commit()
-transaction.commit_transaction()
+with table.update_schema() as update:
+    update.move_first("symbol")
+    update.move_after("bid", "ask")
+    # This will move `created_by` before `exchange`
+    update.move_before("details.created_by", "details.exchange")
 ```
 
-Or, use the UpdateSchema API directly:
+### Update column
+
+Update a field's type, description, or whether it is required.
 
 ```python
 with table.update_schema() as update:
-    update.add_column("x", IntegerType(), "doc")
+    # Promote a float to a double
+    update.update_column("bid", field_type=DoubleType())
+    # Make a field optional
+    update.update_column("symbol", required=False)
+    # Update the documentation
+    update.update_column("symbol", doc="Name of the share on the exchange")
+```
+
+Be careful: some operations are not compatible, but can still be done at your own risk by setting `allow_incompatible_changes`:
+
+```python
+with table.update_schema(allow_incompatible_changes=True) as update:
+    # Incompatible change, cannot require an optional field
+    update.update_column("symbol", required=True)
 ```
 
-Or, without a context manager:
+### Delete column
+
+Delete a field. Be careful: this is an incompatible change (readers/writers might expect this field):
 
 ```python
-table.update_schema().add_column("x", IntegerType(), "doc").commit()
+with table.update_schema(allow_incompatible_changes=True) as update:
+    update.delete_column("some_field")
 ```
 
-### Update table properties
+## Table properties
 
 Set and remove properties through the `Transaction` API:
 
@@ -194,7 +259,7 @@ with table.transaction() as transaction:
 assert table.properties == {}
 ```
 
-Or, without a context manager:
+Or, without context manager:
 
 ```python
 table = table.transaction().set_properties(abc="def").commit_transaction()
@@ -235,7 +300,7 @@ The low level API `plan_files` methods returns a set of tasks that provide the f
 
 ```json
 [
-  "s3a://warehouse/wh/nyc/taxis/data/00003-4-42464649-92dd-41ad-b83b-dea1a2fe4b58-00001.parquet"
+  "s3://warehouse/wh/nyc/taxis/data/00003-4-42464649-92dd-41ad-b83b-dea1a2fe4b58-00001.parquet"
 ]
 ```
 
@@ -343,19 +408,17 @@ Dataset(
 Using [Ray Dataset API](https://docs.ray.io/en/latest/data/api/dataset.html) to interact with the dataset:
 
 ```python
-print(
-    ray_dataset.take(2)
-)
+print(ray_dataset.take(2))
 [
     {
-        'VendorID': 2,
-        'tpep_pickup_datetime': datetime.datetime(2008, 12, 31, 23, 23, 50, tzinfo=),
-        'tpep_dropoff_datetime': datetime.datetime(2009, 1, 1, 0, 34, 31, tzinfo=)
+        "VendorID": 2,
+        "tpep_pickup_datetime": datetime.datetime(2008, 12, 31, 23, 23, 50),
+        "tpep_dropoff_datetime": datetime.datetime(2009, 1, 1, 0, 34, 31),
     },
     {
-        'VendorID': 2,
-        'tpep_pickup_datetime': datetime.datetime(2008, 12, 31, 23, 5, 3, tzinfo=),
-        'tpep_dropoff_datetime': datetime.datetime(2009, 1, 1, 16, 10, 18, tzinfo=)
-    }
+        "VendorID": 2,
+        "tpep_pickup_datetime": datetime.datetime(2008, 12, 31, 23, 5, 3),
+        "tpep_dropoff_datetime": datetime.datetime(2009, 1, 1, 16, 10, 18),
+    },
 ]
 ```
diff --git a/python/mkdocs/docs/contributing.md b/python/mkdocs/docs/contributing.md
index 989cbbea44f8..87a8cc701bb0 100644
--- a/python/mkdocs/docs/contributing.md
+++ b/python/mkdocs/docs/contributing.md
@@ -160,4 +160,4 @@ PyIceberg offers support from Python 3.8 onwards, we can't use the [type hints f
 
 ## Third party libraries
 
-PyIceberg naturally integrates into the rich Python ecosystem, however it is important to be hesistant to add third party packages. Adding a lot of packages makes the library heavyweight, and causes incompatibilities with other projects if they use a different version of the library. Also, big libraries such as `s3fs`, `adlfs`, `pyarrow`, `thrift` should be optional to avoid downloading everything, while not being sure if is actually being used.
+PyIceberg naturally integrates into the rich Python ecosystem; however, it is important to be hesitant about adding third party packages. Adding a lot of packages makes the library heavyweight, and causes incompatibilities with other projects if they use a different version of the library. Also, big libraries such as `s3fs`, `adlfs`, `pyarrow`, `thrift` should be optional, to avoid downloading everything without being sure it is actually being used.
diff --git a/python/pyiceberg/schema.py b/python/pyiceberg/schema.py
index 26a0559cf910..28101809c76f 100644
--- a/python/pyiceberg/schema.py
+++ b/python/pyiceberg/schema.py
@@ -261,6 +261,21 @@ def accessor_for_field(self, field_id: int) -> Accessor:
 
         return self._lazy_id_to_accessor[field_id]
 
+    def identifier_field_names(self) -> Set[str]:
+        """Return the names of the identifier fields.
+
+        Returns:
+            Set of names of the identifier fields
+        """
+        ids = set()
+        for field_id in self.identifier_field_ids:
+            column_name = self.find_column_name(field_id)
+            if column_name is None:
+                raise ValueError(f"Could not find identifier column id: {field_id}")
+            ids.add(column_name)
+
+        return ids
+
     def select(self, *names: str, case_sensitive: bool = True) -> Schema:
         """Return a new schema instance pruned to a subset of columns.
 
@@ -996,12 +1011,6 @@ def __init__(self) -> None:
         self._field_names: List[str] = []
         self._short_field_names: List[str] = []
 
-    def before_map_key(self, key: NestedField) -> None:
-        self.before_field(key)
-
-    def after_map_key(self, key: NestedField) -> None:
-        self.after_field(key)
-
     def before_map_value(self, value: NestedField) -> None:
         if not isinstance(value.field_type, StructType):
             self._short_field_names.append(value.name)
@@ -1265,6 +1274,16 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
 
 
 def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
+    """Prune the schema down to the selected set of field-ids.
+
+    Args:
+        schema: The schema to be pruned.
+        selected: The field-ids to be included.
+ select_full_types: Return the full struct when a subset is recorded + + Returns: + The pruned schema. + """ result = visit(schema.as_struct(), _PruneColumnsVisitor(selected, select_full_types)) return Schema( *(result or StructType()).fields, diff --git a/python/pyiceberg/table/__init__.py b/python/pyiceberg/table/__init__.py index d24550cb7e27..b905c955c848 100644 --- a/python/pyiceberg/table/__init__.py +++ b/python/pyiceberg/table/__init__.py @@ -18,6 +18,7 @@ import itertools from abc import ABC, abstractmethod +from copy import copy from dataclasses import dataclass from enum import Enum from functools import cached_property @@ -40,6 +41,7 @@ from pydantic import Field, SerializeAsAny from sortedcontainers import SortedList +from pyiceberg.exceptions import ResolveError, ValidationError from pyiceberg.expressions import ( AlwaysTrue, And, @@ -63,7 +65,7 @@ Schema, SchemaVisitor, assign_fresh_schema_ids, - index_by_name, + promote, visit, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata @@ -94,7 +96,6 @@ from pyiceberg.catalog import Catalog - ALWAYS_TRUE = AlwaysTrue() TABLE_ROOT_ID = -1 @@ -158,7 +159,7 @@ def _append_requirements(self, *new_requirements: TableRequirement) -> Transacti """ for requirement in new_requirements: type_new_requirement = type(requirement) - if any(type(update) == type_new_requirement for update in self._updates): + if any(type(requirement) == type_new_requirement for update in self._requirements): raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}") self._requirements = self._requirements + new_requirements return self @@ -193,7 +194,7 @@ def update_schema(self) -> UpdateSchema: Returns: A new UpdateSchema. """ - return UpdateSchema(self._table.schema(), self._table, self) + return UpdateSchema(self._table, self) def remove_properties(self, *removals: str) -> Transaction: """Remove properties. @@ -225,17 +226,10 @@ def commit_transaction(self) -> Table: """ # Strip the catalog name if len(self._updates) > 0: - response = self._table.catalog._commit_table( # pylint: disable=W0212 - CommitTableRequest( - identifier=self._table.identifier[1:], - requirements=self._requirements, - updates=self._updates, - ) + self._table._do_commit( # pylint: disable=W0212 + updates=self._updates, + requirements=self._requirements, ) - # Update the metadata with the new one - self._table.metadata = response.metadata - self._table.metadata_location = response.metadata_location - return self._table else: return self._table @@ -410,8 +404,8 @@ class AssertDefaultSortOrderId(TableRequirement): class CommitTableRequest(IcebergBaseModel): identifier: Identifier = Field() - requirements: List[SerializeAsAny[TableRequirement]] = Field(default_factory=list) - updates: List[SerializeAsAny[TableUpdate]] = Field(default_factory=list) + requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple) + updates: Tuple[SerializeAsAny[TableUpdate], ...] 
= Field(default_factory=tuple) class CommitTableResponse(IcebergBaseModel): @@ -527,8 +521,15 @@ def history(self) -> List[SnapshotLogEntry]: """Get the snapshot history of this table.""" return self.metadata.snapshot_log - def update_schema(self) -> UpdateSchema: - return UpdateSchema(self.schema(), self) + def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema: + return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive) + + def _do_commit(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequirement, ...]) -> None: + response = self.catalog._commit_table( # pylint: disable=W0212 + CommitTableRequest(identifier=self.identifier[1:], updates=updates, requirements=requirements) + ) # pylint: disable=W0212 + self.metadata = response.metadata + self.metadata_location = response.metadata_location def __eq__(self, other: Any) -> bool: """Return the equality of two instances of the Table class.""" @@ -889,28 +890,69 @@ def to_ray(self) -> ray.data.dataset.Dataset: return ray.data.from_arrow(self.to_arrow()) +class MoveOperation(Enum): + First = 1 + Before = 2 + After = 3 + + +@dataclass +class Move: + field_id: int + full_name: str + op: MoveOperation + other_field_id: Optional[int] = None + + class UpdateSchema: _table: Table _schema: Schema _last_column_id: itertools.count[int] - _identifier_field_names: List[str] - _adds: Dict[int, List[NestedField]] - _added_name_to_id: Dict[str, int] - _id_to_parent: Dict[int, str] + _identifier_field_names: Set[str] + + _adds: Dict[int, List[NestedField]] = {} + _updates: Dict[int, NestedField] = {} + _deletes: Set[int] = set() + _moves: Dict[int, List[Move]] = {} + + _added_name_to_id: Dict[str, int] = {} + # Part of https://github.com/apache/iceberg/pull/8393 + _id_to_parent: Dict[int, str] = {} _allow_incompatible_changes: bool _case_sensitive: bool _transaction: Optional[Transaction] - def __init__(self, schema: Schema, table: Table, transaction: Optional[Transaction] = None): + def __init__( + self, + table: Table, + transaction: Optional[Transaction] = None, + allow_incompatible_changes: bool = False, + case_sensitive: bool = True, + ) -> None: self._table = table - self._schema = schema - self._last_column_id = itertools.count(schema.highest_field_id + 1) - self._identifier_field_names = schema.column_names + self._schema = table.schema() + self._last_column_id = itertools.count(table.metadata.last_column_id + 1) + self._identifier_field_names = self._schema.identifier_field_names() + self._adds = {} + self._updates = {} + self._deletes = set() + self._moves = {} + self._added_name_to_id = {} - self._id_to_parent = {} - self._allow_incompatible_changes = False - self._case_sensitive = True + + def get_column_name(field_id: int) -> str: + column_name = self._schema.find_column_name(column_id=field_id) + if column_name is None: + raise ValueError(f"Could not find field-id: {field_id}") + return column_name + + self._id_to_parent = { + field_id: get_column_name(parent_field_id) for field_id, parent_field_id in self._schema._lazy_id_to_parent.items() + } + + self._allow_incompatible_changes = allow_incompatible_changes + self._case_sensitive = case_sensitive self._transaction = transaction def __exit__(self, _: Any, value: Any, traceback: Any) -> None: @@ -934,206 +976,583 @@ def case_sensitive(self, case_sensitive: bool) -> UpdateSchema: return self def add_column( - self, name: str, type_var: IcebergType, doc: Optional[str] = 
None, parent: Optional[str] = None, required: bool = False + self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False ) -> UpdateSchema: """Add a new column to a nested struct or Add a new top-level column. + Because "." may be interpreted as a column path separator or may be used in field names, it + is not allowed to add nested column by passing in a string. To add to nested structures or + to add fields with names that contain "." use a tuple instead to indicate the path. + + If type is a nested type, its field IDs are reassigned when added to the existing schema. + Args: - name: Name for the new column. - type_var: Type for the new column. + path: Name for the new column. + field_type: Type for the new column. doc: Documentation string for the new column. - parent: Name of the parent struct to the column will be added to. required: Whether the new column is required. Returns: - This for method chaining + This for method chaining. """ - if "." in name: - raise ValueError(f"Cannot add column with ambiguous name: {name}") + if isinstance(path, str): + if "." in path: + raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead") + path = (path,) if required and not self._allow_incompatible_changes: # Table format version 1 and 2 cannot add required column because there is no initial value - raise ValueError(f"Incompatible change: cannot add required column: {name}") + raise ValueError(f'Incompatible change: cannot add required column: {".".join(path)}') + + name = path[-1] + parent = path[:-1] + + full_name = ".".join(path) + parent_full_path = ".".join(parent) + parent_id: int = TABLE_ROOT_ID + + if len(parent) > 0: + parent_field = self._schema.find_field(parent_full_path, self._case_sensitive) + parent_type = parent_field.field_type + if isinstance(parent_type, MapType): + parent_field = parent_type.value_field + elif isinstance(parent_type, ListType): + parent_field = parent_type.element_field + + if not parent_field.field_type.is_struct: + raise ValueError(f"Cannot add column '{name}' to non-struct type: {parent_full_path}") + + parent_id = parent_field.field_id + + existing_field = None + try: + existing_field = self._schema.find_field(full_name, self._case_sensitive) + except ValueError: + pass + + if existing_field is not None and existing_field.field_id not in self._deletes: + raise ValueError(f"Cannot add column, name already exists: {full_name}") + + # assign new IDs in order + new_id = self.assign_new_column_id() + + # update tracking for moves + self._added_name_to_id[full_name] = new_id + self._id_to_parent[new_id] = parent_full_path + + new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id) + field = NestedField(field_id=new_id, name=name, field_type=new_type, required=required, doc=doc) + + if parent_id in self._adds: + self._adds[parent_id].append(field) + else: + self._adds[parent_id] = [field] - self._internal_add_column(parent, name, not required, type_var, doc) return self - def allow_incompatible_changes(self) -> UpdateSchema: - """Allow incompatible changes to the schema. + def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema: + """Delete a column from a table. + + Args: + path: The path to the column. Returns: - This for method chaining + The UpdateSchema with the delete operation staged. 
""" - self._allow_incompatible_changes = True + name = (path,) if isinstance(path, str) else path + full_name = ".".join(name) + + field = self._schema.find_field(full_name, case_sensitive=self._case_sensitive) + + if field.field_id in self._adds: + raise ValueError(f"Cannot delete a column that has additions: {full_name}") + if field.field_id in self._updates: + raise ValueError(f"Cannot delete a column that has updates: {full_name}") + + self._deletes.add(field.field_id) + return self - def commit(self) -> None: - """Apply the pending changes and commit.""" - new_schema = self._apply() - updates = [ - AddSchemaUpdate(schema=new_schema, last_column_id=new_schema.highest_field_id), - SetCurrentSchemaUpdate(schema_id=-1), - ] - requirements = [AssertCurrentSchemaId(current_schema_id=self._schema.schema_id)] + def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema: + """Update the name of a column. - if self._transaction is not None: - self._transaction._append_updates(*updates) # pylint: disable=W0212 - self._transaction._append_requirements(*requirements) # pylint: disable=W0212 + Args: + path_from: The path to the column to be renamed. + new_name: The new path of the column. + + Returns: + The UpdateSchema with the rename operation staged. + """ + path_from = ".".join(path_from) if isinstance(path_from, tuple) else path_from + field_from = self._schema.find_field(path_from, self._case_sensitive) + + if field_from.field_id in self._deletes: + raise ValueError(f"Cannot rename a column that will be deleted: {path_from}") + + if updated := self._updates.get(field_from.field_id): + self._updates[field_from.field_id] = NestedField( + field_id=updated.field_id, + name=new_name, + field_type=updated.field_type, + doc=updated.doc, + required=updated.required, + ) else: - table_update_response = self._table.catalog._commit_table( # pylint: disable=W0212 - CommitTableRequest(identifier=self._table.identifier[1:], updates=updates, requirements=requirements) + self._updates[field_from.field_id] = NestedField( + field_id=field_from.field_id, + name=new_name, + field_type=field_from.field_type, + doc=field_from.doc, + required=field_from.required, ) - self._table.metadata = table_update_response.metadata - self._table.metadata_location = table_update_response.metadata_location - def _apply(self) -> Schema: - """Apply the pending changes to the original schema and returns the result. + # Lookup the field because of casing + from_field_correct_casing = self._schema.find_column_name(field_from.field_id) + if from_field_correct_casing in self._identifier_field_names: + self._identifier_field_names.remove(from_field_correct_casing) + new_identifier_path = f"{from_field_correct_casing[:-len(field_from.name)]}{new_name}" + self._identifier_field_names.add(new_identifier_path) + + return self + + def make_column_optional(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema: + """Make a column optional. + + Args: + path: The path to the field. Returns: - the result Schema when all pending updates are applied + The UpdateSchema with the requirement change staged. 
""" - return _apply_changes(self._schema, self._adds, self._identifier_field_names) + self._set_column_requirement(path, required=False) + return self - def _internal_add_column( - self, parent: Optional[str], name: str, is_optional: bool, type_var: IcebergType, doc: Optional[str] - ) -> None: - full_name: str = name - parent_id: int = TABLE_ROOT_ID + def set_identifier_fields(self, *fields: str) -> None: + self._identifier_field_names = set(fields) - exist_field: Optional[NestedField] = None - if parent: - parent_field = self._schema.find_field(parent, self._case_sensitive) - parent_type = parent_field.field_type - if isinstance(parent_type, MapType): - parent_field = parent_type.value_field - elif isinstance(parent_type, ListType): - parent_field = parent_type.element_field + def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: bool) -> None: + path = (path,) if isinstance(path, str) else path + name = ".".join(path) - if not parent_field.field_type.is_struct: - raise ValueError(f"Cannot add column to non-struct type: {parent}") + field = self._schema.find_field(name, self._case_sensitive) - parent_id = parent_field.field_id + if (field.required and required) or (field.optional and not required): + # if the change is a noop, allow it even if allowIncompatibleChanges is false + return - try: - exist_field = self._schema.find_field(parent + "." + name, self._case_sensitive) - except ValueError: - pass + if not self._allow_incompatible_changes and required: + raise ValueError(f"Cannot change column nullability: {name}: optional -> required") - if exist_field: - raise ValueError(f"Cannot add column, name already exists: {parent}.{name}") + if field.field_id in self._deletes: + raise ValueError(f"Cannot update a column that will be deleted: {name}") - full_name = parent_field.name + "." + name + if updated := self._updates.get(field.field_id): + self._updates[field.field_id] = NestedField( + field_id=updated.field_id, + name=updated.name, + field_type=updated.field_type, + doc=updated.doc, + required=required, + ) + else: + self._updates[field.field_id] = NestedField( + field_id=field.field_id, + name=field.name, + field_type=field.field_type, + doc=field.doc, + required=required, + ) + + def update_column( + self, + path: Union[str, Tuple[str, ...]], + field_type: Optional[IcebergType] = None, + required: Optional[bool] = None, + doc: Optional[str] = None, + ) -> UpdateSchema: + """Update the type of column. + + Args: + path: The path to the field. + field_type: The new type + required: If the field should be required + doc: Documentation describing the column + Returns: + The UpdateSchema with the type update staged. 
+ """ + path = (path,) if isinstance(path, str) else path + full_name = ".".join(path) + + if field_type is None and required is None and doc is None: + return self + + field = self._schema.find_field(full_name, self._case_sensitive) + + if field.field_id in self._deletes: + raise ValueError(f"Cannot update a column that will be deleted: {full_name}") + + if field_type is not None: + if not field.field_type.is_primitive: + raise ValidationError(f"Cannot change column type: {field.field_type} is not a primitive") + + if not self._allow_incompatible_changes and field.field_type != field_type: + try: + promote(field.field_type, field_type) + except ResolveError as e: + raise ValidationError(f"Cannot change column type: {full_name}: {field.field_type} -> {field_type}") from e + + if updated := self._updates.get(field.field_id): + self._updates[field.field_id] = NestedField( + field_id=updated.field_id, + name=updated.name, + field_type=field_type or updated.field_type, + doc=doc or updated.doc, + required=updated.required, + ) else: - try: - exist_field = self._schema.find_field(name, self._case_sensitive) - except ValueError: - pass + self._updates[field.field_id] = NestedField( + field_id=field.field_id, + name=field.name, + field_type=field_type or field.field_type, + doc=doc or field.doc, + required=field.required, + ) - if exist_field: - raise ValueError(f"Cannot add column, name already exists: {name}") + if required is not None: + self._set_column_requirement(path, required=required) - # assign new IDs in order - new_id = self.assign_new_column_id() + return self - # update tracking for moves - self._added_name_to_id[full_name] = new_id + def _find_for_move(self, name: str) -> Optional[int]: + try: + return self._schema.find_field(name, self._case_sensitive).field_id + except ValueError: + pass + + return self._added_name_to_id.get(name) + + def _move(self, move: Move) -> None: + if parent_name := self._id_to_parent.get(move.field_id): + parent_field = self._schema.find_field(parent_name, case_sensitive=self._case_sensitive) + if not parent_field.field_type.is_struct: + raise ValueError(f"Cannot move fields in non-struct type: {parent_field.field_type}") + + if move.op == MoveOperation.After or move.op == MoveOperation.Before: + if move.other_field_id is None: + raise ValueError("Expected other field when performing before/after move") + + if self._id_to_parent.get(move.field_id) != self._id_to_parent.get(move.other_field_id): + raise ValueError(f"Cannot move field {move.full_name} to a different struct") + + self._moves[parent_field.field_id] = self._moves.get(parent_field.field_id, []) + [move] + else: + # In the top level field + if move.op == MoveOperation.After or move.op == MoveOperation.Before: + if move.other_field_id is None: + raise ValueError("Expected other field when performing before/after move") + + if other_struct := self._id_to_parent.get(move.other_field_id): + raise ValueError(f"Cannot move field {move.full_name} to a different struct: {other_struct}") + + self._moves[TABLE_ROOT_ID] = self._moves.get(TABLE_ROOT_ID, []) + [move] + + def move_first(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema: + """Move the field to the first position of the parent struct. + + Args: + path: The path to the field. + + Returns: + The UpdateSchema with the move operation staged. 
+ """ + full_name = ".".join(path) if isinstance(path, tuple) else path + + field_id = self._find_for_move(full_name) + + if field_id is None: + raise ValueError(f"Cannot move missing column: {full_name}") + + self._move(Move(field_id=field_id, full_name=full_name, op=MoveOperation.First)) + + return self + + def move_before(self, path: Union[str, Tuple[str, ...]], before_path: Union[str, Tuple[str, ...]]) -> UpdateSchema: + """Move the field to before another field. + + Args: + path: The path to the field. + + Returns: + The UpdateSchema with the move operation staged. + """ + full_name = ".".join(path) if isinstance(path, tuple) else path + field_id = self._find_for_move(full_name) + + if field_id is None: + raise ValueError(f"Cannot move missing column: {full_name}") + + before_full_name = ( + ".".join( + before_path, + ) + if isinstance(before_path, tuple) + else before_path + ) + before_field_id = self._find_for_move(before_full_name) + + if before_field_id is None: + raise ValueError(f"Cannot move {full_name} before missing column: {before_full_name}") + + if field_id == before_field_id: + raise ValueError(f"Cannot move {full_name} before itself") + + self._move(Move(field_id=field_id, full_name=full_name, other_field_id=before_field_id, op=MoveOperation.Before)) + + return self + + def move_after(self, path: Union[str, Tuple[str, ...]], after_name: Union[str, Tuple[str, ...]]) -> UpdateSchema: + """Move the field to after another field. + + Args: + path: The path to the field. + + Returns: + The UpdateSchema with the move operation staged. + """ + full_name = ".".join(path) if isinstance(path, tuple) else path + + field_id = self._find_for_move(full_name) + + if field_id is None: + raise ValueError(f"Cannot move missing column: {full_name}") + + after_path = ".".join(after_name) if isinstance(after_name, tuple) else after_name + after_field_id = self._find_for_move(after_path) - new_type = assign_fresh_schema_ids(type_var, self.assign_new_column_id) - field = NestedField(new_id, name, new_type, not is_optional, doc) + if after_field_id is None: + raise ValueError(f"Cannot move {full_name} after missing column: {after_path}") - self._adds.setdefault(parent_id, []).append(field) + if field_id == after_field_id: + raise ValueError(f"Cannot move {full_name} after itself") + + self._move(Move(field_id=field_id, full_name=full_name, other_field_id=after_field_id, op=MoveOperation.After)) + + return self + + def commit(self) -> None: + """Apply the pending changes and commit.""" + new_schema = self._apply() + + if new_schema != self._schema: + last_column_id = max(self._table.metadata.last_column_id, new_schema.highest_field_id) + updates = ( + AddSchemaUpdate(schema=new_schema, last_column_id=last_column_id), + SetCurrentSchemaUpdate(schema_id=-1), + ) + requirements = (AssertCurrentSchemaId(current_schema_id=self._schema.schema_id),) + + if self._transaction is not None: + self._transaction._append_updates(*updates) # pylint: disable=W0212 + self._transaction._append_requirements(*requirements) # pylint: disable=W0212 + else: + self._table._do_commit(updates=updates, requirements=requirements) # pylint: disable=W0212 + + def _apply(self) -> Schema: + """Apply the pending changes to the original schema and returns the result. 
+ + Returns: + the result Schema when all pending updates are applied + """ + struct = visit(self._schema, _ApplyChanges(self._adds, self._updates, self._deletes, self._moves)) + if struct is None: + # Should never happen + raise ValueError("Could not apply changes") + + # Check the field-ids + new_schema = Schema(*struct.fields) + field_ids = set() + for name in self._identifier_field_names: + try: + field = new_schema.find_field(name, case_sensitive=self._case_sensitive) + except ValueError as e: + raise ValueError( + f"Cannot find identifier field {name}. In case of deletion, update the identifier fields first." + ) from e + + field_ids.add(field.field_id) + + return Schema(*struct.fields, schema_id=1 + max(self._table.schemas().keys()), identifier_field_ids=field_ids) def assign_new_column_id(self) -> int: return next(self._last_column_id) -def _apply_changes(schema_: Schema, adds: Dict[int, List[NestedField]], identifier_field_names: List[str]) -> Schema: - struct = visit(schema_, _ApplyChanges(adds)) - name_to_id: Dict[str, int] = index_by_name(struct) - for name in identifier_field_names: - if name not in name_to_id: - raise ValueError(f"Cannot add field {name} as an identifier field: not found in current schema or added columns") +class _ApplyChanges(SchemaVisitor[Optional[IcebergType]]): + _adds: Dict[int, List[NestedField]] + _updates: Dict[int, NestedField] + _deletes: Set[int] + _moves: Dict[int, List[Move]] - return Schema(*struct.fields) + def __init__( + self, adds: Dict[int, List[NestedField]], updates: Dict[int, NestedField], deletes: Set[int], moves: Dict[int, List[Move]] + ) -> None: + self._adds = adds + self._updates = updates + self._deletes = deletes + self._moves = moves + def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]: + added = self._adds.get(TABLE_ROOT_ID) + moves = self._moves.get(TABLE_ROOT_ID) -class _ApplyChanges(SchemaVisitor[IcebergType]): - def __init__(self, adds: Dict[int, List[NestedField]]): - self.adds = adds + if added is not None or moves is not None: + if not isinstance(struct_result, StructType): + raise ValueError(f"Cannot add fields to non-struct: {struct_result}") - def schema(self, schema: Schema, struct_result: IcebergType) -> IcebergType: - fields = _ApplyChanges.add_fields(schema.as_struct().fields, self.adds.get(TABLE_ROOT_ID)) - if len(fields) > 0: - return StructType(*fields) + if new_fields := _add_and_move_fields(struct_result.fields, added or [], moves or []): + return StructType(*new_fields) return struct_result - def struct(self, struct: StructType, field_results: List[IcebergType]) -> IcebergType: - has_change = False - new_fields: List[NestedField] = [] - for i in range(len(field_results)): - type_: Optional[IcebergType] = field_results[i] - if type_ is None: - has_change = True + def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: + has_changes = False + new_fields = [] + + for idx, result_type in enumerate(field_results): + result_type = field_results[idx] + + # Has been deleted + if result_type is None: + has_changes = True continue - field: NestedField = struct.fields[i] - new_fields.append(field) + field = struct.fields[idx] + + name = field.name + doc = field.doc + required = field.required - if has_change: + # There is an update + if update := self._updates.get(field.field_id): + name = update.name + doc = update.doc + required = update.required + + if field.name == name and field.field_type == result_type and 
field.required == required and field.doc == doc: + new_fields.append(field) + else: + has_changes = True + new_fields.append( + NestedField(field_id=field.field_id, name=name, field_type=result_type, required=required, doc=doc) + ) + + if has_changes: return StructType(*new_fields) return struct - def field(self, field: NestedField, field_result: IcebergType) -> IcebergType: - field_id: int = field.field_id - if field_id in self.adds: - new_fields = self.adds[field_id] - if len(new_fields) > 0: - fields = _ApplyChanges.add_fields(field_result.fields, new_fields) - if len(fields) > 0: - return StructType(*fields) + def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]: + # the API validates deletes, updates, and additions don't conflict handle deletes + if field.field_id in self._deletes: + return None + + # handle updates + if (update := self._updates.get(field.field_id)) and field.field_type != update.field_type: + return update.field_type + + if isinstance(field_result, StructType): + # handle add & moves + added = self._adds.get(field.field_id) + moves = self._moves.get(field.field_id) + if added is not None or moves is not None: + if not isinstance(field.field_type, StructType): + raise ValueError(f"Cannot add fields to non-struct: {field}") + + if new_fields := _add_and_move_fields(field_result.fields, added or [], moves or []): + return StructType(*new_fields) return field_result - def list(self, list_type: ListType, element_result: IcebergType) -> IcebergType: - element_field: NestedField = list_type.element_field - element_type = self.field(element_field, element_result) + def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: + element_type = self.field(list_type.element_field, element_result) if element_type is None: - raise ValueError(f"Cannot delete element type from list: {element_field}") + raise ValueError(f"Cannot delete element type from list: {element_result}") - is_element_optional = not list_type.element_required + return ListType(element_id=list_type.element_id, element=element_type, element_required=list_type.element_required) - if is_element_optional == element_field.required and list_type.element_type == element_type: - return list_type + def map( + self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType] + ) -> Optional[IcebergType]: + key_id: int = map_type.key_field.field_id - return ListType(list_type.element_id, element_type, is_element_optional) + if key_id in self._deletes: + raise ValueError(f"Cannot delete map keys: {map_type}") - def map(self, map_type: MapType, key_result: IcebergType, value_result: IcebergType) -> IcebergType: - key_id: int = map_type.key_field.field_id - if key_id in self.adds: + if key_id in self._updates: + raise ValueError(f"Cannot update map keys: {map_type}") + + if key_id in self._adds: raise ValueError(f"Cannot add fields to map keys: {map_type}") + if map_type.key_type != key_result: + raise ValueError(f"Cannot alter map keys: {map_type}") + value_field: NestedField = map_type.value_field value_type = self.field(value_field, value_result) if value_type is None: raise ValueError(f"Cannot delete value type from map: {value_field}") - is_value_optional = not map_type.value_required + return MapType( + key_id=map_type.key_id, + key_type=map_type.key_type, + value_id=map_type.value_id, + value_type=value_type, + value_required=map_type.value_required, + ) + + def primitive(self, primitive: PrimitiveType) -> 
Optional[IcebergType]: + return primitive - if is_value_optional != value_field.required and map_type.value_type == value_type: - return map_type - return MapType(map_type.key_id, map_type.key_field, map_type.value_id, value_type, not is_value_optional) +def _add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> Tuple[NestedField, ...]: + adds = adds or [] + return fields + tuple(adds) - def primitive(self, primitive: PrimitiveType) -> IcebergType: - return primitive - @staticmethod - def add_fields(fields: Tuple[NestedField, ...], adds: Optional[List[NestedField]]) -> List[NestedField]: - new_fields: List[NestedField] = [] - new_fields.extend(fields) - if adds: - new_fields.extend(adds) - return new_fields +def _move_fields(fields: Tuple[NestedField, ...], moves: List[Move]) -> Tuple[NestedField, ...]: + reordered = list(copy(fields)) + for move in moves: + # Find the field that we're about to move + field = next(field for field in reordered if field.field_id == move.field_id) + # Remove the field that we're about to move from the list + reordered = [field for field in reordered if field.field_id != move.field_id] + + if move.op == MoveOperation.First: + reordered = [field] + reordered + elif move.op == MoveOperation.Before or move.op == MoveOperation.After: + other_field_id = move.other_field_id + other_field_pos = next(i for i, field in enumerate(reordered) if field.field_id == other_field_id) + if move.op == MoveOperation.Before: + reordered.insert(other_field_pos, field) + else: + reordered.insert(other_field_pos + 1, field) + else: + raise ValueError(f"Unknown operation: {move.op}") + + return tuple(reordered) + + +def _add_and_move_fields( + fields: Tuple[NestedField, ...], adds: List[NestedField], moves: List[Move] +) -> Optional[Tuple[NestedField, ...]]: + if len(adds) > 0: + # always apply adds first so that added fields can be moved + added = _add_fields(fields, adds) + if len(moves) > 0: + return _move_fields(added, moves) + else: + return added + elif len(moves) > 0: + return _move_fields(fields, moves) + return None if len(adds) == 0 else tuple(*fields, *adds) diff --git a/python/tests/catalog/test_base.py b/python/tests/catalog/test_base.py index 29e93d0c9d05..e4da808014f6 100644 --- a/python/tests/catalog/test_base.py +++ b/python/tests/catalog/test_base.py @@ -542,7 +542,7 @@ def test_commit_table(catalog: InMemoryCatalog) -> None: def test_add_column(catalog: InMemoryCatalog) -> None: given_table = given_catalog_has_a_table(catalog) - given_table.update_schema().add_column(name="new_column1", type_var=IntegerType()).commit() + given_table.update_schema().add_column(path="new_column1", field_type=IntegerType()).commit() assert given_table.schema() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), @@ -554,7 +554,7 @@ def test_add_column(catalog: InMemoryCatalog) -> None: ) transaction = given_table.transaction() - transaction.update_schema().add_column(name="new_column2", type_var=IntegerType(), doc="doc").commit() + transaction.update_schema().add_column(path="new_column2", field_type=IntegerType(), doc="doc").commit() transaction.commit_transaction() assert given_table.schema() == Schema( @@ -572,7 +572,7 @@ def test_add_column_with_statement(catalog: InMemoryCatalog) -> None: given_table = given_catalog_has_a_table(catalog) with given_table.update_schema() as tx: - tx.add_column(name="new_column1", type_var=IntegerType()) + tx.add_column(path="new_column1", field_type=IntegerType()) assert given_table.schema() 
== Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), @@ -584,7 +584,7 @@ def test_add_column_with_statement(catalog: InMemoryCatalog) -> None: ) with given_table.transaction() as tx: - tx.update_schema().add_column(name="new_column2", type_var=IntegerType(), doc="doc").commit() + tx.update_schema().add_column(path="new_column2", field_type=IntegerType(), doc="doc").commit() assert given_table.schema() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), diff --git a/python/tests/table/test_init.py b/python/tests/table/test_init.py index b25e445032fd..3ee0cd37f2d9 100644 --- a/python/tests/table/test_init.py +++ b/python/tests/table/test_init.py @@ -396,23 +396,24 @@ def test_serialize_set_properties_updates() -> None: assert SetPropertiesUpdate(updates={"abc": "🤪"}).model_dump_json() == """{"action":"set-properties","updates":{"abc":"🤪"}}""" -def test_add_column(table_schema_simple: Schema, table: Table) -> None: - update = UpdateSchema(table_schema_simple, table) - update.add_column(name="b", type_var=IntegerType()) +def test_add_column(table: Table) -> None: + update = UpdateSchema(table) + update.add_column(path="b", field_type=IntegerType()) apply_schema: Schema = update._apply() # pylint: disable=W0212 assert len(apply_schema.fields) == 4 assert apply_schema == Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), NestedField(field_id=4, name="b", field_type=IntegerType(), required=False), + identifier_field_ids=[1, 2], ) - assert apply_schema.schema_id == 0 + assert apply_schema.schema_id == 2 assert apply_schema.highest_field_id == 4 -def test_add_primitive_type_column(table_schema_simple: Schema, table: Table) -> None: +def test_add_primitive_type_column(table: Table) -> None: primitive_type: Dict[str, PrimitiveType] = { "boolean": BooleanType(), "int": IntegerType(), @@ -430,8 +431,8 @@ def test_add_primitive_type_column(table_schema_simple: Schema, table: Table) -> for name, type_ in primitive_type.items(): field_name = f"new_column_{name}" - update = UpdateSchema(table_schema_simple, table) - update.add_column(parent=None, name=field_name, type_var=type_, doc=f"new_column_{name}") + update = UpdateSchema(table) + update.add_column(path=field_name, field_type=type_, doc=f"new_column_{name}") new_schema = update._apply() # pylint: disable=W0212 field: NestedField = new_schema.find_field(field_name) @@ -439,15 +440,15 @@ def test_add_primitive_type_column(table_schema_simple: Schema, table: Table) -> assert field.doc == f"new_column_{name}" -def test_add_nested_type_column(table_schema_simple: Schema, table: Table) -> None: +def test_add_nested_type_column(table: Table) -> None: # add struct type column field_name = "new_column_struct" - update = UpdateSchema(table_schema_simple, table) + update = UpdateSchema(table) struct_ = StructType( NestedField(1, "lat", DoubleType()), NestedField(2, "long", DoubleType()), ) - update.add_column(parent=None, name=field_name, type_var=struct_) + update.add_column(path=field_name, field_type=struct_) schema_ = update._apply() # pylint: disable=W0212 field: 
NestedField = schema_.find_field(field_name) assert field.field_type == StructType( @@ -457,22 +458,22 @@ def test_add_nested_type_column(table_schema_simple: Schema, table: Table) -> No assert schema_.highest_field_id == 6 -def test_add_nested_map_type_column(table_schema_simple: Schema, table: Table) -> None: +def test_add_nested_map_type_column(table: Table) -> None: # add map type column field_name = "new_column_map" - update = UpdateSchema(table_schema_simple, table) + update = UpdateSchema(table) map_ = MapType(1, StringType(), 2, IntegerType(), False) - update.add_column(parent=None, name=field_name, type_var=map_) + update.add_column(path=field_name, field_type=map_) new_schema = update._apply() # pylint: disable=W0212 field: NestedField = new_schema.find_field(field_name) assert field.field_type == MapType(5, StringType(), 6, IntegerType(), False) assert new_schema.highest_field_id == 6 -def test_add_nested_list_type_column(table_schema_simple: Schema, table: Table) -> None: +def test_add_nested_list_type_column(table: Table) -> None: # add list type column field_name = "new_column_list" - update = UpdateSchema(table_schema_simple, table) + update = UpdateSchema(table) list_ = ListType( element_id=101, element_type=StructType( @@ -481,7 +482,7 @@ def test_add_nested_list_type_column(table_schema_simple: Schema, table: Table) ), element_required=False, ) - update.add_column(parent=None, name=field_name, type_var=list_) + update.add_column(path=field_name, field_type=list_) new_schema = update._apply() # pylint: disable=W0212 field: NestedField = new_schema.find_field(field_name) assert field.field_type == ListType( @@ -493,77 +494,3 @@ def test_add_nested_list_type_column(table_schema_simple: Schema, table: Table) element_required=False, ) assert new_schema.highest_field_id == 7 - - -def test_add_field_to_map_key(table_schema_nested_with_struct_key_map: Schema, table: Table) -> None: - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(table_schema_nested_with_struct_key_map, table) - update.add_column(name="b", type_var=IntegerType(), parent="location.key")._apply() # pylint: disable=W0212 - assert "Cannot add fields to map keys" in str(exc_info.value) - - -def test_add_already_exists(table_schema_nested: Schema, table: Table) -> None: - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(table_schema_nested, table) - update.add_column("foo", IntegerType()) - assert "already exists: foo" in str(exc_info.value) - - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(table_schema_nested, table) - update.add_column(name="latitude", type_var=IntegerType(), parent="location") - assert "already exists: location.lat" in str(exc_info.value) - - -def test_add_to_non_struct_type(table_schema_simple: Schema, table: Table) -> None: - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(table_schema_simple, table) - update.add_column(name="lat", type_var=IntegerType(), parent="foo") - assert "Cannot add column to non-struct type" in str(exc_info.value) - - -def test_add_required_column(table: Table) -> None: - schema_ = Schema( - NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[] - ) - - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(schema_, table) - update.add_column(name="data", type_var=IntegerType(), required=True) - assert "Incompatible change: cannot add required column: data" in str(exc_info.value) - - new_schema = ( - UpdateSchema(schema_, 
table) # pylint: disable=W0212 - .allow_incompatible_changes() - .add_column(name="data", type_var=IntegerType(), required=True) - ._apply() - ) - assert new_schema == Schema( - NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="data", field_type=IntegerType(), required=True), - schema_id=0, - identifier_field_ids=[], - ) - - -def test_add_required_column_case_insensitive(table: Table) -> None: - schema_ = Schema( - NestedField(field_id=1, name="id", field_type=BooleanType(), required=False), schema_id=1, identifier_field_ids=[] - ) - - with pytest.raises(ValueError) as exc_info: - update = UpdateSchema(schema_, table) - update.allow_incompatible_changes().case_sensitive(False).add_column(name="ID", type_var=IntegerType(), required=True) - assert "already exists: ID" in str(exc_info.value) - - new_schema = ( - UpdateSchema(schema_, table) # pylint: disable=W0212 - .allow_incompatible_changes() - .add_column(name="ID", type_var=IntegerType(), required=True) - ._apply() - ) - assert new_schema == Schema( - NestedField(field_id=1, name="id", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="ID", field_type=IntegerType(), required=True), - schema_id=0, - identifier_field_ids=[], - ) diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py index acd694677463..a63436bdaead 100644 --- a/python/tests/test_integration.py +++ b/python/tests/test_integration.py @@ -25,7 +25,7 @@ from pyarrow.fs import S3FileSystem from pyiceberg.catalog import Catalog, load_catalog -from pyiceberg.exceptions import CommitFailedException, NoSuchTableError +from pyiceberg.exceptions import NoSuchTableError from pyiceberg.expressions import ( And, EqualTo, @@ -40,14 +40,10 @@ from pyiceberg.table import Table from pyiceberg.types import ( BooleanType, - DoubleType, - FixedType, IntegerType, - LongType, NestedField, StringType, TimestampType, - UUIDType, ) @@ -356,89 +352,3 @@ def test_unpartitioned_fixed_table(catalog: Catalog) -> None: b"12345678901234567ass12345", b"qweeqwwqq1231231231231111", ] - - -@pytest.mark.integration -def test_schema_evolution(catalog: Catalog) -> None: - try: - catalog.drop_table("default.test_schema_evolution") - except NoSuchTableError: - pass - - schema = Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - ) - - t = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) - - assert t.schema() == schema - - with t.update_schema() as tx: - tx.add_column("col_string", StringType()) - - assert t.schema() == Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), - schema_id=1, - ) - - -@pytest.mark.integration -def test_schema_evolution_via_transaction(catalog: Catalog) -> None: - try: - catalog.drop_table("default.test_schema_evolution") - except NoSuchTableError: - pass - - schema = Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - ) - - tbl = catalog.create_table(identifier="default.test_schema_evolution", schema=schema) - - assert tbl.schema() == schema - - with tbl.transaction() as tx: - 
tx.update_schema().add_column("col_string", StringType()).commit() - - assert tbl.schema() == Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), - schema_id=1, - ) - - tbl.update_schema().add_column("col_integer", IntegerType()).commit() - - assert tbl.schema() == Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), - NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), - schema_id=1, - ) - - with pytest.raises(CommitFailedException) as exc_info: - with tbl.transaction() as tx: - # Start a new update - schema_update = tx.update_schema() - - # Do a concurrent update - tbl.update_schema().add_column("col_long", LongType()).commit() - - # stage another update in the transaction - schema_update.add_column("col_double", DoubleType()).commit() - - assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value) - - assert tbl.schema() == Schema( - NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), - NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), - NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), - NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), - NestedField(field_id=5, name="col_long", field_type=LongType(), required=False), - schema_id=1, - ) diff --git a/python/tests/test_integration_schema.py b/python/tests/test_integration_schema.py new file mode 100644 index 000000000000..f0ccb1b0e858 --- /dev/null +++ b/python/tests/test_integration_schema.py @@ -0,0 +1,2471 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name + +import pytest + +from pyiceberg.catalog import Catalog, load_catalog +from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError +from pyiceberg.schema import Schema, prune_columns +from pyiceberg.table import Table, UpdateSchema +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DecimalType, + DoubleType, + FixedType, + FloatType, + IntegerType, + ListType, + LongType, + MapType, + NestedField, + PrimitiveType, + StringType, + StructType, + TimestampType, + TimestamptzType, + TimeType, + UUIDType, +) + + +@pytest.fixture() +def catalog() -> Catalog: + return load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) + + +@pytest.fixture() +def simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table: + return _create_table_with_schema(catalog, table_schema_simple) + + +def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table: + tbl_name = "default.test_schema_evolution" + try: + catalog.drop_table(tbl_name) + except NoSuchTableError: + pass + return catalog.create_table(identifier=tbl_name, schema=schema) + + +@pytest.mark.integration +def test_add_already_exists(catalog: Catalog, table_schema_nested: Schema) -> None: + table = _create_table_with_schema(catalog, table_schema_nested) + update = UpdateSchema(table) + + with pytest.raises(ValueError) as exc_info: + update.add_column("foo", IntegerType()) + assert "already exists: foo" in str(exc_info.value) + + with pytest.raises(ValueError) as exc_info: + update.add_column(path=("location", "latitude"), field_type=IntegerType()) + assert "already exists: location.latitude" in str(exc_info.value) + + +@pytest.mark.integration +def test_add_to_non_struct_type(catalog: Catalog, table_schema_simple: Schema) -> None: + table = _create_table_with_schema(catalog, table_schema_simple) + update = UpdateSchema(table) + with pytest.raises(ValueError) as exc_info: + update.add_column(path=("foo", "lat"), field_type=IntegerType()) + assert "Cannot add column 'lat' to non-struct type: foo" in str(exc_info.value) + + +@pytest.mark.integration +def test_schema_evolution_nested_field(catalog: Catalog) -> None: + schema = Schema( + NestedField( + field_id=1, + name="foo", + field_type=StructType(NestedField(2, name="bar", field_type=StringType(), required=False)), + required=False, + ), + ) + tbl = _create_table_with_schema(catalog, schema) + + assert tbl.schema() == schema + + with pytest.raises(ValidationError) as exc_info: + with tbl.transaction() as tx: + tx.update_schema().update_column("foo", StringType()).commit() + + assert "Cannot change column type: struct<2: bar: optional string> is not a primitive" in str(exc_info.value) + + +@pytest.mark.integration +def test_schema_evolution_via_transaction(catalog: Catalog) -> None: + schema = Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + ) + tbl = _create_table_with_schema(catalog, schema) + + assert tbl.schema() == schema + + with tbl.transaction() as tx: + tx.update_schema().add_column("col_string", StringType()).commit() + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), 
required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + ) + + tbl.update_schema().add_column("col_integer", IntegerType()).commit() + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + ) + + with pytest.raises(CommitFailedException) as exc_info: + with tbl.transaction() as tx: + # Start a new update + schema_update = tx.update_schema() + + # Do a concurrent update + tbl.update_schema().add_column("col_long", LongType()).commit() + + # stage another update in the transaction + schema_update.add_column("col_double", DoubleType()).commit() + + assert "Requirement failed: current schema changed: expected id 2 != 3" in str(exc_info.value) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="col_uuid", field_type=UUIDType(), required=False), + NestedField(field_id=2, name="col_fixed", field_type=FixedType(25), required=False), + NestedField(field_id=3, name="col_string", field_type=StringType(), required=False), + NestedField(field_id=4, name="col_integer", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="col_long", field_type=LongType(), required=False), + ) + + +@pytest.mark.integration +def test_schema_evolution_nested(catalog: Catalog) -> None: + nested_schema = Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + key_id=10, + key_type=StringType(), + value_id=11, + value_type=StructType( + NestedField(field_id=110, name="x", field_type=FloatType(), required=False), + NestedField(field_id=111, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + element_id=20, + element_type=StructType( + NestedField(field_id=200, name="x", field_type=FloatType(), required=False), + NestedField(field_id=201, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + name="person", + field_type=StructType( + NestedField(field_id=30, name="name", field_type=StringType(), required=False), + NestedField(field_id=31, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + ) + + tbl = _create_table_with_schema(catalog, nested_schema) + + assert tbl.schema().highest_field_id == 12 + + with tbl.update_schema() as schema_update: + schema_update.add_column(("location_lookup", "z"), FloatType()) + schema_update.add_column(("locations", "z"), FloatType()) + schema_update.add_column(("person", "address"), StringType()) + + assert str(tbl.schema()) == str( + Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + type="map", + key_id=4, + key_type=StringType(), + value_id=5, + value_type=StructType( + NestedField(field_id=6, name="x", field_type=FloatType(), required=False), + NestedField(field_id=7, name="y", field_type=FloatType(), required=False), + NestedField(field_id=13, name="z", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + type="list", + element_id=8, + element_type=StructType( + 
NestedField(field_id=9, name="x", field_type=FloatType(), required=False), + NestedField(field_id=10, name="y", field_type=FloatType(), required=False), + NestedField(field_id=14, name="z", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + name="person", + field_type=StructType( + NestedField(field_id=11, name="name", field_type=StringType(), required=False), + NestedField(field_id=12, name="age", field_type=IntegerType(), required=True), + NestedField(field_id=15, name="address", field_type=StringType(), required=False), + ), + required=False, + ), + ) + ) + + +schema_nested = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(type="list", element_id=8, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=5, + name="quux", + field_type=MapType( + type="map", + key_id=9, + key_type=StringType(), + value_id=10, + value_type=MapType( + type="map", key_id=11, key_type=StringType(), value_id=12, value_type=IntegerType(), value_required=True + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=6, + name="location", + field_type=ListType( + type="list", + element_id=13, + element_type=StructType( + NestedField(field_id=14, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=15, name="longitude", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=7, + name="person", + field_type=StructType( + NestedField(field_id=16, name="name", field_type=StringType(), required=False), + NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + identifier_field_ids=[2], +) + + +@pytest.fixture() +def nested_table(catalog: Catalog) -> Table: + return _create_table_with_schema(catalog, schema_nested) + + +@pytest.mark.integration +def test_no_changes(simple_table: Table, table_schema_simple: Schema) -> None: + with simple_table.update_schema() as _: + pass + + assert simple_table.schema() == table_schema_simple + + +@pytest.mark.integration +def test_no_changes_empty_commit(simple_table: Table, table_schema_simple: Schema) -> None: + with simple_table.update_schema() as update: + # No updates, so this should be a noop + update.update_column(path="foo") + + assert simple_table.schema() == table_schema_simple + + +@pytest.mark.integration +def test_delete_field(simple_table: Table) -> None: + with simple_table.update_schema() as schema_update: + schema_update.delete_column("foo") + + assert simple_table.schema() == Schema( + # foo is missing 👍 + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + identifier_field_ids=[2], + ) + + +@pytest.mark.integration +def test_delete_field_case_insensitive(simple_table: Table) -> None: + with simple_table.update_schema(case_sensitive=False) as schema_update: + schema_update.delete_column("FOO") + + assert simple_table.schema() == Schema( + # foo is missing 👍 + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), 
required=False), + identifier_field_ids=[2], + ) + + +@pytest.mark.integration +def test_delete_identifier_fields(simple_table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + with simple_table.update_schema() as schema_update: + schema_update.delete_column("bar") + + assert "Cannot find identifier field bar. In case of deletion, update the identifier fields first." in str(exc_info) + + +@pytest.mark.integration +def test_delete_identifier_fields_nested(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField( + field_id=2, + name="person", + field_type=StructType( + NestedField(field_id=3, name="name", field_type=StringType(), required=True), + NestedField(field_id=4, name="age", field_type=IntegerType(), required=True), + ), + required=True, + ), + identifier_field_ids=[3], + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.delete_column("person") + + assert "Cannot find identifier field person.name. In case of deletion, update the identifier fields first." in str(exc_info) + + +@pytest.mark.parametrize( + "field", + [ + "foo", + "baz", + "qux", + "quux", + "location", + "location.element.latitude", + "location.element.longitude", + "person", + "person.name", + "person.age", + ], +) +@pytest.mark.integration +def test_deletes(field: str, nested_table: Table) -> None: + with nested_table.update_schema() as schema_update: + schema_update.delete_column(field) + + selected_ids = { + field_id + for field_id in schema_nested.field_ids + if not isinstance(schema_nested.find_field(field_id).field_type, (MapType, ListType)) + and not schema_nested.find_column_name(field_id).startswith(field) # type: ignore + } + expected_schema = prune_columns(schema_nested, selected_ids, select_full_types=False) + + assert expected_schema == nested_table.schema() + + +@pytest.mark.parametrize( + "field", + [ + "Foo", + "Baz", + "Qux", + "Quux", + "Location", + "Location.element.latitude", + "Location.element.longitude", + "Person", + "Person.name", + "Person.age", + ], +) +@pytest.mark.integration +def test_deletes_case_insensitive(field: str, nested_table: Table) -> None: + with nested_table.update_schema(case_sensitive=False) as schema_update: + schema_update.delete_column(field) + + selected_ids = { + field_id + for field_id in schema_nested.field_ids + if not isinstance(schema_nested.find_field(field_id).field_type, (MapType, ListType)) + and not schema_nested.find_column_name(field_id).startswith(field.lower()) # type: ignore + } + expected_schema = prune_columns(schema_nested, selected_ids, select_full_types=False) + + assert expected_schema == nested_table.schema() + + +@pytest.mark.integration +def test_update_types(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="bar", field_type=IntegerType(), required=True), + NestedField( + field_id=2, + name="location", + field_type=ListType( + type="list", + element_id=3, + element_type=StructType( + NestedField(field_id=4, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=5, name="longitude", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.update_column("bar", LongType()) + schema_update.update_column("location.latitude", DoubleType()) + 
schema_update.update_column("location.longitude", DoubleType()) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="bar", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="location", + field_type=ListType( + type="list", + element_id=3, + element_type=StructType( + NestedField(field_id=4, name="latitude", field_type=DoubleType(), required=False), + NestedField(field_id=5, name="longitude", field_type=DoubleType(), required=False), + ), + element_required=True, + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_update_types_case_insensitive(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="bar", field_type=IntegerType(), required=True), + NestedField( + field_id=2, + name="location", + field_type=ListType( + type="list", + element_id=3, + element_type=StructType( + NestedField(field_id=4, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=5, name="longitude", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + ), + ) + + with tbl.update_schema(case_sensitive=False) as schema_update: + schema_update.update_column("baR", LongType()) + schema_update.update_column("Location.Latitude", DoubleType()) + schema_update.update_column("Location.Longitude", DoubleType()) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="bar", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="location", + field_type=ListType( + type="list", + element_id=3, + element_type=StructType( + NestedField(field_id=4, name="latitude", field_type=DoubleType(), required=False), + NestedField(field_id=5, name="longitude", field_type=DoubleType(), required=False), + ), + element_required=True, + ), + required=True, + ), + ) + + +allowed_promotions = [ + (StringType(), BinaryType()), + (BinaryType(), StringType()), + (IntegerType(), LongType()), + (FloatType(), DoubleType()), + (DecimalType(9, 2), DecimalType(18, 2)), +] + + +@pytest.mark.parametrize("from_type, to_type", allowed_promotions, ids=str) +@pytest.mark.integration +def test_allowed_updates(from_type: PrimitiveType, to_type: PrimitiveType, catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="bar", field_type=from_type, required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.update_column("bar", to_type) + + assert tbl.schema() == Schema(NestedField(field_id=1, name="bar", field_type=to_type, required=True)) + + +disallowed_promotions_types = [ + BooleanType(), + IntegerType(), + LongType(), + FloatType(), + DoubleType(), + DateType(), + TimeType(), + TimestampType(), + TimestamptzType(), + StringType(), + UUIDType(), + BinaryType(), + FixedType(3), + FixedType(4), + # We'll just allow Decimal promotions right now + # https://github.com/apache/iceberg/issues/8389 + # DecimalType(9, 2), + # DecimalType(9, 3), + DecimalType(18, 2), +] + + +@pytest.mark.parametrize("from_type", disallowed_promotions_types, ids=str) +@pytest.mark.parametrize("to_type", disallowed_promotions_types, ids=str) +@pytest.mark.integration +def test_disallowed_updates(from_type: PrimitiveType, to_type: PrimitiveType, catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="bar", field_type=from_type, required=True), + ), + ) + + if from_type != to_type and (from_type, to_type) not in allowed_promotions: + with 
pytest.raises(ValidationError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.update_column("bar", to_type) + + assert str(exc_info.value).startswith("Cannot change column type: bar:") + else: + with tbl.update_schema() as schema_update: + schema_update.update_column("bar", to_type) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="bar", field_type=to_type, required=True), + ) + + +@pytest.mark.integration +def test_rename_simple(simple_table: Table) -> None: + with simple_table.update_schema() as schema_update: + schema_update.rename_column("foo", "vo") + + assert simple_table.schema() == Schema( + NestedField(field_id=1, name="vo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + identifier_field_ids=[2], + ) + + +@pytest.mark.integration +def test_rename_simple_nested(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="foo", + field_type=StructType(NestedField(field_id=2, name="bar", field_type=StringType())), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.rename_column("foo.bar", "vo") + + assert tbl.schema() == Schema( + NestedField( + field_id=1, + name="foo", + field_type=StructType(NestedField(field_id=2, name="vo", field_type=StringType())), + required=True, + ), + ) + + +@pytest.mark.integration +def test_rename_simple_nested_with_dots(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="a.b", + field_type=StructType(NestedField(field_id=2, name="c.d", field_type=StringType())), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.rename_column(("a.b", "c.d"), "e.f") + + assert tbl.schema() == Schema( + NestedField( + field_id=1, + name="a.b", + field_type=StructType(NestedField(field_id=2, name="e.f", field_type=StringType())), + required=True, + ), + ) + + +@pytest.mark.integration +def test_rename(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + type="map", + key_id=5, + key_type=StringType(), + value_id=6, + value_type=StructType( + NestedField(field_id=7, name="x", field_type=FloatType(), required=False), + NestedField(field_id=8, name="y", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + type="list", + element_id=9, + element_type=StructType( + NestedField(field_id=10, name="x", field_type=FloatType(), required=False), + NestedField(field_id=11, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + name="person", + field_type=StructType( + NestedField(field_id=12, name="name", field_type=StringType(), required=False), + NestedField(field_id=13, name="leeftijd", field_type=IntegerType(), required=True), + ), + required=False, + ), + NestedField(field_id=4, name="foo", field_type=StringType(), required=True), + identifier_field_ids=[], + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.rename_column("foo", "bar") + schema_update.rename_column("location_lookup.x", "latitude") + schema_update.rename_column("locations.x", 
"latitude") + schema_update.rename_column("person.leeftijd", "age") + + assert tbl.schema() == Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + type="map", + key_id=5, + key_type=StringType(), + value_id=6, + value_type=StructType( + NestedField(field_id=7, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=8, name="y", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + type="list", + element_id=9, + element_type=StructType( + NestedField(field_id=10, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=11, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + name="person", + field_type=StructType( + NestedField(field_id=12, name="name", field_type=StringType(), required=False), + NestedField(field_id=13, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + NestedField(field_id=4, name="bar", field_type=StringType(), required=True), + identifier_field_ids=[], + ) + + +@pytest.mark.integration +def test_rename_case_insensitive(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + type="map", + key_id=5, + key_type=StringType(), + value_id=6, + value_type=StructType( + NestedField(field_id=7, name="x", field_type=FloatType(), required=False), + NestedField(field_id=8, name="y", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + type="list", + element_id=9, + element_type=StructType( + NestedField(field_id=10, name="x", field_type=FloatType(), required=False), + NestedField(field_id=11, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + name="person", + field_type=StructType( + NestedField(field_id=12, name="name", field_type=StringType(), required=False), + NestedField(field_id=13, name="leeftijd", field_type=IntegerType(), required=True), + ), + required=True, + ), + NestedField(field_id=4, name="foo", field_type=StringType(), required=True), + identifier_field_ids=[13], + ), + ) + + with tbl.update_schema(case_sensitive=False) as schema_update: + schema_update.rename_column("Foo", "bar") + schema_update.rename_column("Location_lookup.X", "latitude") + schema_update.rename_column("Locations.X", "latitude") + schema_update.rename_column("Person.Leeftijd", "age") + + assert tbl.schema() == Schema( + NestedField( + field_id=1, + name="location_lookup", + field_type=MapType( + type="map", + key_id=5, + key_type=StringType(), + value_id=6, + value_type=StructType( + NestedField(field_id=7, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=8, name="y", field_type=FloatType(), required=False), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=2, + name="locations", + field_type=ListType( + type="list", + element_id=9, + element_type=StructType( + NestedField(field_id=10, name="latitude", field_type=FloatType(), required=False), + NestedField(field_id=11, name="y", field_type=FloatType(), required=False), + ), + element_required=True, + ), + required=True, + ), + NestedField( + field_id=3, + 
name="person", + field_type=StructType( + NestedField(field_id=12, name="name", field_type=StringType(), required=False), + NestedField(field_id=13, name="age", field_type=IntegerType(), required=True), + ), + required=True, + ), + NestedField(field_id=4, name="bar", field_type=StringType(), required=True), + identifier_field_ids=[13], + ) + + +@pytest.mark.integration +def test_add_struct(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType()), + ), + ) + + struct = StructType( + NestedField(field_id=3, name="x", field_type=DoubleType(), required=False), + NestedField(field_id=4, name="y", field_type=DoubleType(), required=False), + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column("location", struct) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="foo", field_type=StringType()), + NestedField(field_id=2, name="location", field_type=struct, required=False), + ) + + +@pytest.mark.integration +def test_add_nested_map_of_structs(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType()), + ), + ) + + map_type_example = MapType( + key_id=1, + value_id=2, + key_type=StructType( + NestedField(field_id=20, name="address", field_type=StringType(), required=True), + NestedField(field_id=21, name="city", field_type=StringType(), required=True), + NestedField(field_id=22, name="state", field_type=StringType(), required=True), + NestedField(field_id=23, name="zip", field_type=IntegerType(), required=True), + ), + value_type=StructType( + NestedField(field_id=9, name="lat", field_type=DoubleType(), required=True), + NestedField(field_id=8, name="long", field_type=DoubleType(), required=False), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column("locations", map_type_example) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + NestedField( + field_id=2, + name="locations", + field_type=MapType( + type="map", + key_id=3, + key_type=StructType( + NestedField(field_id=5, name="address", field_type=StringType(), required=True), + NestedField(field_id=6, name="city", field_type=StringType(), required=True), + NestedField(field_id=7, name="state", field_type=StringType(), required=True), + NestedField(field_id=8, name="zip", field_type=IntegerType(), required=True), + ), + value_id=4, + value_type=StructType( + NestedField(field_id=9, name="lat", field_type=DoubleType(), required=True), + NestedField(field_id=10, name="long", field_type=DoubleType(), required=False), + ), + value_required=True, + ), + required=False, + ), + ) + + +@pytest.mark.integration +def test_add_nested_list_of_structs(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType()), + ), + ) + + list_type_examples = ListType( + element_id=1, + element_type=StructType( + NestedField(field_id=9, name="lat", field_type=DoubleType(), required=True), + NestedField(field_id=10, name="long", field_type=DoubleType(), required=False), + ), + element_required=False, + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column("locations", list_type_examples) + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + NestedField( + field_id=2, + name="locations", + 
field_type=ListType( + type="list", + element_id=3, + element_type=StructType( + NestedField(field_id=4, name="lat", field_type=DoubleType(), required=True), + NestedField(field_id=5, name="long", field_type=DoubleType(), required=False), + ), + element_required=False, + ), + required=False, + ), + ) + + +@pytest.mark.integration +def test_add_required_column(catalog: Catalog) -> None: + schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) + table = _create_table_with_schema(catalog, schema_) + update = UpdateSchema(table) + with pytest.raises(ValueError) as exc_info: + update.add_column(path="data", field_type=IntegerType(), required=True) + assert "Incompatible change: cannot add required column: data" in str(exc_info.value) + + new_schema = ( + UpdateSchema(table, allow_incompatible_changes=True) # pylint: disable=W0212 + .add_column(path="data", field_type=IntegerType(), required=True) + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="data", field_type=IntegerType(), required=True), + ) + + +@pytest.mark.integration +def test_add_required_column_case_insensitive(catalog: Catalog) -> None: + schema_ = Schema(NestedField(field_id=1, name="id", field_type=BooleanType(), required=False)) + table = _create_table_with_schema(catalog, schema_) + + with pytest.raises(ValueError) as exc_info: + with UpdateSchema(table, allow_incompatible_changes=True) as update: + update.case_sensitive(False).add_column(path="ID", field_type=IntegerType(), required=True) + assert "already exists: ID" in str(exc_info.value) + + new_schema = ( + UpdateSchema(table, allow_incompatible_changes=True) # pylint: disable=W0212 + .add_column(path="ID", field_type=IntegerType(), required=True) + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="id", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="ID", field_type=IntegerType(), required=True), + ) + + +@pytest.mark.integration +def test_make_column_optional(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.make_column_optional("foo") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + ) + + +@pytest.mark.integration +def test_mixed_changes(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=StringType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=False), + NestedField( + field_id=3, + name="preferences", + field_type=StructType( + NestedField(field_id=8, name="feature1", type=BooleanType(), required=True), + NestedField(field_id=9, name="feature2", type=BooleanType(), required=False), + ), + required=False, + ), + NestedField( + field_id=4, + name="locations", + field_type=MapType( + key_id=10, + value_id=11, + key_type=StructType( + NestedField(field_id=20, name="address", field_type=StringType(), required=True), + NestedField(field_id=21, name="city", field_type=StringType(), required=True), + NestedField(field_id=22, name="state", field_type=StringType(), required=True), + NestedField(field_id=23, name="zip", field_type=IntegerType(), required=True), + ), + value_type=StructType( + 
NestedField(field_id=12, name="lat", field_type=DoubleType(), required=True), + NestedField(field_id=13, name="long", field_type=DoubleType(), required=False), + ), + ), + required=True, + ), + NestedField( + field_id=5, + name="points", + field_type=ListType( + element_id=14, + element_type=StructType( + NestedField(field_id=15, name="x", field_type=LongType(), required=True), + NestedField(field_id=16, name="y", field_type=LongType(), required=True), + ), + ), + required=True, + doc="2-D cartesian points", + ), + NestedField(field_id=6, name="doubles", field_type=ListType(element_id=17, element_type=DoubleType()), required=True), + NestedField( + field_id=7, + name="properties", + field_type=MapType(key_id=18, value_id=19, key_type=StringType(), value_type=StringType()), + required=False, + ), + ), + ) + + with tbl.update_schema(allow_incompatible_changes=True) as schema_update: + schema_update.add_column("toplevel", field_type=DecimalType(9, 2)) + schema_update.add_column(("locations", "alt"), field_type=FloatType()) + schema_update.add_column(("points", "z"), field_type=LongType()) + schema_update.add_column(("points", "t.t"), field_type=LongType(), doc="name with '.'") + schema_update.rename_column("data", "json") + schema_update.rename_column("preferences", "options") + schema_update.rename_column("preferences.feature2", "newfeature") + schema_update.rename_column("locations.lat", "latitude") + schema_update.rename_column("points.x", "X") + schema_update.rename_column("points.y", "y.y") + schema_update.update_column("id", field_type=LongType(), doc="unique id") + schema_update.update_column("locations.lat", DoubleType()) + schema_update.update_column("locations.lat", doc="latitude") + schema_update.delete_column("locations.long") + schema_update.delete_column("properties") + schema_update.make_column_optional("points.x") + schema_update.update_column("data", required=True) + schema_update.add_column(("locations", "description"), StringType(), doc="location description") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True, doc="unique id"), + NestedField(field_id=2, name="json", field_type=StringType(), required=True), + NestedField( + field_id=3, + name="options", + field_type=StructType( + NestedField(field_id=8, name="feature1", field_type=BooleanType(), required=True), + NestedField(field_id=9, name="newfeature", field_type=BooleanType(), required=False), + ), + required=False, + ), + NestedField( + field_id=4, + name="locations", + field_type=MapType( + type="map", + key_id=10, + key_type=StructType( + NestedField(field_id=12, name="address", field_type=StringType(), required=True), + NestedField(field_id=13, name="city", field_type=StringType(), required=True), + NestedField(field_id=14, name="state", field_type=StringType(), required=True), + NestedField(field_id=15, name="zip", field_type=IntegerType(), required=True), + ), + value_id=11, + value_type=StructType( + NestedField(field_id=16, name="latitude", field_type=DoubleType(), required=True, doc="latitude"), + NestedField(field_id=25, name="alt", field_type=FloatType(), required=False), + NestedField( + field_id=28, name="description", field_type=StringType(), required=False, doc="location description" + ), + ), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=5, + name="points", + field_type=ListType( + type="list", + element_id=18, + element_type=StructType( + NestedField(field_id=19, name="X", field_type=LongType(), required=False), + 
NestedField(field_id=20, name="y.y", field_type=LongType(), required=True), + NestedField(field_id=26, name="z", field_type=LongType(), required=False), + NestedField(field_id=27, name="t.t", field_type=LongType(), required=False, doc="name with '.'"), + ), + element_required=True, + ), + doc="2-D cartesian points", + required=True, + ), + NestedField( + field_id=6, + name="doubles", + field_type=ListType(type="list", element_id=21, element_type=DoubleType(), element_required=True), + required=True, + ), + NestedField(field_id=24, name="toplevel", field_type=DecimalType(precision=9, scale=2), required=False), + ) + + +@pytest.mark.integration +def test_ambiguous_column(catalog: Catalog, table_schema_nested: Schema) -> None: + table = _create_table_with_schema(catalog, table_schema_nested) + update = UpdateSchema(table) + + with pytest.raises(ValueError) as exc_info: + update.add_column(path="location.latitude", field_type=IntegerType()) + assert "Cannot add column with ambiguous name: location.latitude, provide a tuple instead" in str(exc_info.value) + + +@pytest.mark.integration +def test_delete_then_add(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.delete_column("foo") + schema_update.add_column("foo", StringType()) + + assert tbl.schema() == Schema( + NestedField(field_id=2, name="foo", field_type=StringType(), required=False), + ) + + +@pytest.mark.integration +def test_delete_then_add_nested(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="preferences", + field_type=StructType( + NestedField(field_id=2, name="feature1", field_type=BooleanType()), + NestedField(field_id=3, name="feature2", field_type=BooleanType()), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.delete_column("preferences.feature1") + schema_update.add_column(("preferences", "feature1"), BooleanType()) + + assert tbl.schema() == Schema( + NestedField( + field_id=1, + name="preferences", + field_type=StructType( + NestedField(field_id=3, name="feature2", field_type=BooleanType()), + NestedField(field_id=4, name="feature1", field_type=BooleanType(), required=False), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_delete_missing_column(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.delete_column("bar") + + assert "Could not find field with name bar, case_sensitive=True" in str(exc_info.value) + + +@pytest.mark.integration +def test_add_delete_conflict(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.add_column("bar", BooleanType()) + schema_update.delete_column("bar") + assert "Could not find field with name bar, case_sensitive=True" in str(exc_info.value) + + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField( + field_id=1, + name="preferences", + field_type=StructType( + NestedField(field_id=2, name="feature1", 
field_type=BooleanType()), + NestedField(field_id=3, name="feature2", field_type=BooleanType()), + ), + required=True, + ), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.add_column(("preferences", "feature3"), BooleanType()) + schema_update.delete_column("preferences") + assert "Cannot delete a column that has additions: preferences" in str(exc_info.value) + + +@pytest.mark.integration +def test_rename_missing_column(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.rename_column("bar", "fail") + + assert "Could not find field with name bar, case_sensitive=True" in str(exc_info.value) + + +@pytest.mark.integration +def test_rename_missing_conflicts(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.rename_column("foo", "bar") + schema_update.delete_column("foo") + + assert "Cannot delete a column that has updates: foo" in str(exc_info.value) + + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.rename_column("foo", "bar") + schema_update.delete_column("bar") + + assert "Could not find field with name bar, case_sensitive=True" in str(exc_info.value) + + +@pytest.mark.integration +def test_update_missing_column(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.update_column("bar", DateType()) + + assert "Could not find field with name bar, case_sensitive=True" in str(exc_info.value) + + +@pytest.mark.integration +def test_update_delete_conflict(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=IntegerType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.update_column("foo", LongType()) + schema_update.delete_column("foo") + + assert "Cannot delete a column that has updates: foo" in str(exc_info.value) + + +@pytest.mark.integration +def test_delete_update_conflict(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="foo", field_type=IntegerType(), required=True), + ), + ) + + with pytest.raises(ValueError) as exc_info: + with tbl.update_schema() as schema_update: + schema_update.delete_column("foo") + schema_update.update_column("foo", LongType()) + + assert "Cannot update a column that will be deleted: foo" in str(exc_info.value) + + +@pytest.mark.integration +def test_delete_map_key(nested_table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + with nested_table.update_schema() as schema_update: + schema_update.delete_column("quux.key") + + assert "Cannot delete map keys" in str(exc_info.value) + + 
+@pytest.mark.integration
+def test_add_field_to_map_key(nested_table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        with nested_table.update_schema() as schema_update:
+            schema_update.add_column(("quux", "key"), StringType())
+
+    assert "Cannot add column 'key' to non-struct type: quux" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_alter_map_key(nested_table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        with nested_table.update_schema() as schema_update:
+            schema_update.update_column(("quux", "key"), BinaryType())
+
+    assert "Cannot update map keys" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_update_map_key(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(
+                field_id=1, name="m", field_type=MapType(key_id=2, value_id=3, key_type=IntegerType(), value_type=DoubleType())
+            )
+        ),
+    )
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as schema_update:
+            schema_update.update_column("m.key", LongType())
+
+    assert "Cannot update map keys: map<int, double>" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_update_added_column_doc(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as schema_update:
+            schema_update.add_column("value", LongType())
+            schema_update.update_column("value", doc="a value")
+
+    assert "Could not find field with name value, case_sensitive=True" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_update_deleted_column_doc(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as schema_update:
+            schema_update.delete_column("foo")
+            schema_update.update_column("foo", doc="a value")
+
+    assert "Cannot update a column that will be deleted: foo" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_multiple_moves(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="a", field_type=IntegerType(), required=True),
+            NestedField(field_id=2, name="b", field_type=IntegerType(), required=True),
+            NestedField(field_id=3, name="c", field_type=IntegerType(), required=True),
+            NestedField(field_id=4, name="d", field_type=IntegerType(), required=True),
+        ),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.move_first("d")
+        schema_update.move_first("c")
+        schema_update.move_after("b", "d")
+        schema_update.move_before("d", "a")
+
+    assert tbl.schema() == Schema(
+        NestedField(field_id=3, name="c", field_type=IntegerType(), required=True),
+        NestedField(field_id=2, name="b", field_type=IntegerType(), required=True),
+        NestedField(field_id=4, name="d", field_type=IntegerType(), required=True),
+        NestedField(field_id=1, name="a", field_type=IntegerType(), required=True),
+    )
+
+
+@pytest.mark.integration
+def test_move_top_level_column_first(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+        ),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.move_first("data")
+
+    assert
tbl.schema() == Schema( + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + ) + + +@pytest.mark.integration +def test_move_top_level_column_before_first(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_before("data", "id") + + assert tbl.schema() == Schema( + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + ) + + +@pytest.mark.integration +def test_move_top_level_column_after_last(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_after("id", "data") + + assert tbl.schema() == Schema( + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + ) + + +@pytest.mark.integration +def test_move_nested_field_first(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_first("struct.data") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_move_nested_field_before_first(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_before("struct.data", "struct.count") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_move_nested_field_after_first(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + 
NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_before("struct.data", "struct.count") + + assert str(tbl.schema()) == str( + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + ), + required=True, + ), + ) + ) + + +@pytest.mark.integration +def test_move_nested_field_after(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_after("struct.ts", "struct.count") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_move_nested_field_before(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_before("struct.ts", "struct.data") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_move_map_value_struct_field(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="map", + field_type=MapType( + key_id=3, + value_id=4, + key_type=StringType(), + value_type=StructType( + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + NestedField(field_id=6, name="count", field_type=LongType(), required=True), + 
NestedField(field_id=7, name="data", field_type=StringType(), required=True), + ), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.move_before("map.ts", "map.data") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="map", + field_type=MapType( + key_id=3, + value_id=4, + key_type=StringType(), + value_type=StructType( + NestedField(field_id=6, name="count", field_type=LongType(), required=True), + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=True), + NestedField(field_id=7, name="data", field_type=StringType(), required=True), + ), + ), + required=True, + ), + ) + + +@pytest.mark.integration +def test_move_added_top_level_column(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column("ts", TimestamptzType()) + schema_update.move_after("ts", "id") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=3, name="ts", field_type=TimestamptzType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ) + + +@pytest.mark.integration +def test_move_added_top_level_column_after_added_column(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column("ts", TimestamptzType()) + schema_update.add_column("count", LongType()) + schema_update.move_after("ts", "id") + schema_update.move_after("count", "ts") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField(field_id=3, name="ts", field_type=TimestamptzType(), required=False), + NestedField(field_id=4, name="count", field_type=LongType(), required=False), + NestedField(field_id=2, name="data", field_type=StringType(), required=True), + ) + + +@pytest.mark.integration +def test_move_added_nested_struct_field(catalog: Catalog) -> None: + tbl = _create_table_with_schema( + catalog, + Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ), + ) + + with tbl.update_schema() as schema_update: + schema_update.add_column(("struct", "ts"), TimestamptzType()) + schema_update.move_before("struct.ts", "struct.count") + + assert tbl.schema() == Schema( + NestedField(field_id=1, name="id", field_type=LongType(), required=True), + NestedField( + field_id=2, + name="struct", + field_type=StructType( + NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=False), + NestedField(field_id=3, name="count", field_type=LongType(), required=True), + NestedField(field_id=4, name="data", field_type=StringType(), required=True), + ), + required=True, + ), + ) + + +@pytest.mark.integration 
+def test_move_added_nested_field_before_added_column(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(
+                field_id=2,
+                name="struct",
+                field_type=StructType(
+                    NestedField(field_id=3, name="count", field_type=LongType(), required=True),
+                    NestedField(field_id=4, name="data", field_type=StringType(), required=True),
+                ),
+                required=True,
+            ),
+        ),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.add_column(("struct", "ts"), TimestamptzType())
+        schema_update.add_column(("struct", "size"), LongType())
+        schema_update.move_before("struct.ts", "struct.count")
+        schema_update.move_before("struct.size", "struct.ts")
+
+    assert tbl.schema() == Schema(
+        NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+        NestedField(
+            field_id=2,
+            name="struct",
+            field_type=StructType(
+                NestedField(field_id=6, name="size", field_type=LongType(), required=False),
+                NestedField(field_id=5, name="ts", field_type=TimestamptzType(), required=False),
+                NestedField(field_id=3, name="count", field_type=LongType(), required=True),
+                NestedField(field_id=4, name="data", field_type=StringType(), required=True),
+            ),
+            required=True,
+        ),
+    )
+
+
+@pytest.mark.integration
+def test_move_self_reference_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType()),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("foo", "foo")
+    assert "Cannot move foo before itself" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_after("foo", "foo")
+    assert "Cannot move foo after itself" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_missing_column_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType()),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_first("items")
+    assert "Cannot move missing column: items" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("items", "id")
+    assert "Cannot move missing column: items" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_after("items", "data")
+    assert "Cannot move missing column: items" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_before_add_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType()),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_first("ts")
+            update.add_column("ts", TimestamptzType())
+    assert "Cannot move missing column: ts" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("ts", "id")
+            update.add_column("ts", TimestamptzType())
+    assert "Cannot move missing column: ts" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_after("ts", "data")
+            update.add_column("ts", TimestamptzType())
+    assert "Cannot move missing column: ts" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_missing_reference_column_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("id", "items")
+    assert "Cannot move id before missing column: items" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_after("data", "items")
+    assert "Cannot move data after missing column: items" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_primitive_map_key_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+            NestedField(
+                field_id=3,
+                name="map",
+                field_type=MapType(key_id=4, value_id=5, key_type=StringType(), value_type=StringType()),
+                required=False,
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("map.key", "map.value")
+    assert "Cannot move fields in non-struct type: map<string, string>" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_primitive_map_value_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+            NestedField(
+                field_id=3,
+                name="map",
+                field_type=MapType(key_id=4, value_id=5, key_type=StringType(), value_type=StructType()),
+                required=False,
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("map.value", "map.key")
+    assert "Cannot move fields in non-struct type: map<string, struct<>>" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_top_level_between_structs_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="a", field_type=IntegerType(), required=True),
+            NestedField(field_id=2, name="b", field_type=IntegerType(), required=True),
+            NestedField(
+                field_id=3,
+                name="struct",
+                field_type=StructType(
+                    NestedField(field_id=4, name="x", field_type=IntegerType(), required=True),
+                    NestedField(field_id=5, name="y", field_type=IntegerType(), required=True),
+                ),
+                required=False,
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("a", "struct.x")
+    assert "Cannot move field a to a different struct" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_move_between_structs_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(
+                field_id=1,
+                name="s1",
+                field_type=StructType(
+                    NestedField(field_id=3, name="a", field_type=IntegerType(), required=True),
+                    NestedField(field_id=4, name="b", field_type=IntegerType(), required=True),
+                ),
+                required=False,
+            ),
+            NestedField(
+                field_id=2,
+                name="s2",
+                field_type=StructType(
+                    NestedField(field_id=5, name="x", field_type=IntegerType(), required=True),
+                    NestedField(field_id=6, name="y", field_type=IntegerType(), required=True),
+                ),
+                required=False,
+            ),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update:
+            update.move_before("s2.x", "s1.a")
+
+    assert "Cannot move field s2.x to a different struct" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_add_existing_identifier_fields(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema() as update_schema:
+        update_schema.set_identifier_fields("foo")
+
+    assert tbl.schema().identifier_field_names() == {"foo"}
+
+
+@pytest.mark.integration
+def test_add_new_identifiers_field_columns(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column("new_field", StringType(), required=True)
+        update_schema.set_identifier_fields("foo", "new_field")
+
+    assert tbl.schema().identifier_field_names() == {"foo", "new_field"}
+
+
+@pytest.mark.integration
+def test_add_new_identifiers_field_columns_out_of_order(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.set_identifier_fields("foo", "new_field")
+        update_schema.add_column("new_field", StringType(), required=True)
+
+    assert tbl.schema().identifier_field_names() == {"foo", "new_field"}
+
+
+@pytest.mark.integration
+def test_add_nested_identifier_field_columns(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column(
+            "required_struct", StructType(NestedField(field_id=3, name="field", type=StringType(), required=True)), required=True
+        )
+
+    with tbl.update_schema() as update_schema:
+        update_schema.set_identifier_fields("required_struct.field")
+
+    assert tbl.schema().identifier_field_names() == {"required_struct.field"}
+
+
+@pytest.mark.integration
+def test_add_nested_identifier_field_columns_single_transaction(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column(
+            "new", StructType(NestedField(field_id=3, name="field", type=StringType(), required=True)), required=True
+        )
+        update_schema.set_identifier_fields("new.field")
+
+    assert tbl.schema().identifier_field_names() == {"new.field"}
+
+
+@pytest.mark.integration
+def test_add_nested_nested_identifier_field_columns(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column(
+            "new",
+            StructType(
+                NestedField(
+                    field_id=3,
+                    name="field",
+                    type=StructType(NestedField(field_id=4, name="nested", type=StringType(), required=True)),
+                    required=True,
+                )
+            ),
+            required=True,
+        )
+        update_schema.set_identifier_fields("new.field.nested")
+
+    assert tbl.schema().identifier_field_names() == {"new.field.nested"}
+
+
+@pytest.mark.integration
+def test_add_dotted_identifier_field_columns(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column(("dot.field",), StringType(), required=True)
+        update_schema.set_identifier_fields("dot.field")
+
+    assert tbl.schema().identifier_field_names() == {"dot.field"}
+
+
+@pytest.mark.integration
+def test_remove_identifier_fields(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.add_column(("new_field",), StringType(), required=True)
+        update_schema.add_column(("new_field2",), StringType(), required=True)
+        update_schema.set_identifier_fields("foo", "new_field", "new_field2")
+
+    assert tbl.schema().identifier_field_names() == {"foo", "new_field", "new_field2"}
+
+    with tbl.update_schema(allow_incompatible_changes=True) as update_schema:
+        update_schema.set_identifier_fields()
+
+    assert tbl.schema().identifier_field_names() == set()
+
+
+@pytest.mark.integration
+def test_set_identifier_field_fails_schema(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=IntegerType(), required=False),
+            NestedField(field_id=2, name="float", field_type=FloatType(), required=True),
+            NestedField(field_id=3, name="double", field_type=DoubleType(), required=True),
+            identifier_field_ids=[],
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update_schema:
+            update_schema.set_identifier_fields("id")
+
+    assert "Identifier field 1 invalid: not a required field" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update_schema:
+            update_schema.set_identifier_fields("float")
+
+    assert "Identifier field 2 invalid: must not be float or double field" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update_schema:
+            update_schema.set_identifier_fields("double")
+
+    assert "Identifier field 3 invalid: must not be float or double field" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as update_schema:
+            update_schema.set_identifier_fields("unknown")
+
+    assert "Cannot find identifier field unknown. In case of deletion, update the identifier fields first." in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_set_identifier_field_fails(nested_table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        with nested_table.update_schema() as update_schema:
+            update_schema.set_identifier_fields("location")
+
+    assert "Identifier field 6 invalid: not a primitive type field" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with nested_table.update_schema() as update_schema:
+            update_schema.set_identifier_fields("baz")
+
+    assert "Identifier field 3 invalid: not a required field" in str(exc_info.value)
+
+    with pytest.raises(ValueError) as exc_info:
+        with nested_table.update_schema() as update_schema:
+            update_schema.set_identifier_fields("person.name")
+
+    assert "Identifier field 16 invalid: not a required field" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_delete_identifier_field_columns(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.delete_column("foo")
+        schema_update.set_identifier_fields()
+
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.set_identifier_fields()
+        schema_update.delete_column("foo")
+
+
+@pytest.mark.integration
+def test_delete_containing_nested_identifier_field_columns_fails(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema(allow_incompatible_changes=True) as schema_update:
+        schema_update.add_column(
+            "out", StructType(NestedField(field_id=3, name="nested", field_type=StringType(), required=True)), required=True
+        )
+        schema_update.set_identifier_fields("out.nested")
+
+    assert tbl.schema() == Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=True),
+        NestedField(
+            field_id=2,
+            name="out",
+            field_type=StructType(NestedField(field_id=3, name="nested", field_type=StringType(), required=True)),
+            required=True,
+        ),
+        identifier_field_ids=[3],
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.update_schema() as schema_update:
+            schema_update.delete_column("out")
+
+    assert "Cannot find identifier field out.nested. In case of deletion, update the identifier fields first." in str(exc_info)
+
+
+@pytest.mark.integration
+def test_rename_identifier_fields(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(NestedField(field_id=1, name="foo", field_type=StringType(), required=True), identifier_field_ids=[1]),
+    )
+
+    with tbl.update_schema() as schema_update:
+        schema_update.rename_column("foo", "bar")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"bar"}
+
+
+@pytest.mark.integration
+def test_move_identifier_fields(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+            identifier_field_ids=[1],
+        ),
+    )
+
+    with tbl.update_schema() as update:
+        update.move_before("data", "id")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+    with tbl.update_schema() as update:
+        update.move_after("id", "data")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+    with tbl.update_schema() as update:
+        update.move_first("data")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+
+@pytest.mark.integration
+def test_move_identifier_fields_case_insensitive(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="id", field_type=LongType(), required=True),
+            NestedField(field_id=2, name="data", field_type=StringType(), required=True),
+            identifier_field_ids=[1],
+        ),
+    )
+
+    with tbl.update_schema(case_sensitive=False) as update:
+        update.move_before("DATA", "ID")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+    with tbl.update_schema(case_sensitive=False) as update:
+        update.move_after("ID", "DATA")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+    with tbl.update_schema(case_sensitive=False) as update:
+        update.move_first("DATA")
+
+    assert tbl.schema().identifier_field_ids == [1]
+    assert tbl.schema().identifier_field_names() == {"id"}
+
+
+@pytest.mark.integration
+def test_two_add_schemas_in_a_single_transaction(catalog: Catalog) -> None:
+    tbl = _create_table_with_schema(
+        catalog,
+        Schema(
+            NestedField(field_id=1, name="foo", field_type=StringType()),
+        ),
+    )
+
+    with pytest.raises(ValueError) as exc_info:
+        with tbl.transaction() as tr:
+            with tr.update_schema() as update:
+                update.add_column("bar", field_type=StringType())
+            with tr.update_schema() as update:
+                update.add_column("baz", field_type=StringType())
+
+    assert "Updates in a single commit need to be unique, duplicate: <class 'pyiceberg.table.AddSchemaUpdate'>" in str(
+        exc_info.value
+    )
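
The tests above rely on a `catalog` pytest fixture and a `_create_table_with_schema` helper that are defined earlier in the test module and are not part of this hunk. The following is a minimal sketch of what they might look like, assuming the REST catalog and MinIO endpoints from the integration docker-compose setup; the catalog name, the `default.test_schema_evolution` identifier, and the drop-and-recreate behaviour are illustrative assumptions, not code from this change:

```python
import pytest

from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.schema import Schema
from pyiceberg.table import Table


@pytest.fixture()
def catalog() -> Catalog:
    # Assumed REST catalog + MinIO endpoints from dev/docker-compose-integration.yml.
    return load_catalog(
        "local",
        **{
            "type": "rest",
            "uri": "http://localhost:8181",
            "s3.endpoint": "http://localhost:9000",
            "s3.access-key-id": "admin",
            "s3.secret-access-key": "password",
        },
    )


def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table:
    # Hypothetical helper: drop any leftover table, then recreate it with the given schema,
    # so every test starts from a clean table state.
    tbl_name = "default.test_schema_evolution"
    try:
        catalog.drop_table(tbl_name)
    except NoSuchTableError:
        pass
    return catalog.create_table(identifier=tbl_name, schema=schema)
```

Recreating the table for every test keeps the schema-evolution cases independent of one another when the suite is re-run against the same long-lived catalog.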