diff --git a/python/poetry.lock b/python/poetry.lock index bc47c6e866b9..307f529dd673 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -251,6 +251,21 @@ category = "main" optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +[[package]] +name = "pydantic" +version = "1.9.1" +description = "Data validation and settings management using python type hints" +category = "main" +optional = false +python-versions = ">=3.6.1" + +[package.dependencies] +typing-extensions = ">=3.7.4.3" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] + [[package]] name = "pyparsing" version = "3.0.9" @@ -340,6 +355,14 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "typing-extensions" +version = "4.2.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" +optional = false +python-versions = ">=3.7" + [[package]] name = "virtualenv" version = "20.14.1" @@ -392,7 +415,7 @@ snappy = ["python-snappy"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "439429e65911b8e768bd4617a4883e8a4ec2652df3a43b28755346da8ed17e19" +content-hash = "c60cc02a89d6d22566af086e4f94dcf12662a59d0a9ceacbfbd90c65b5b34308" [metadata.files] atomicwrites = [ @@ -661,6 +684,43 @@ pycparser = [ {file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"}, {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] +pydantic = [ + {file = "pydantic-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c8098a724c2784bf03e8070993f6d46aa2eeca031f8d8a048dff277703e6e193"}, + {file = "pydantic-1.9.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c320c64dd876e45254bdd350f0179da737463eea41c43bacbee9d8c9d1021f11"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18f3e912f9ad1bdec27fb06b8198a2ccc32f201e24174cec1b3424dda605a310"}, + {file = "pydantic-1.9.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c11951b404e08b01b151222a1cb1a9f0a860a8153ce8334149ab9199cd198131"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8bc541a405423ce0e51c19f637050acdbdf8feca34150e0d17f675e72d119580"}, + {file = "pydantic-1.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e565a785233c2d03724c4dc55464559639b1ba9ecf091288dd47ad9c629433bd"}, + {file = "pydantic-1.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:a4a88dcd6ff8fd47c18b3a3709a89adb39a6373f4482e04c1b765045c7e282fd"}, + {file = "pydantic-1.9.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:447d5521575f18e18240906beadc58551e97ec98142266e521c34968c76c8761"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:985ceb5d0a86fcaa61e45781e567a59baa0da292d5ed2e490d612d0de5796918"}, + {file = "pydantic-1.9.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:059b6c1795170809103a1538255883e1983e5b831faea6558ef873d4955b4a74"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:d12f96b5b64bec3f43c8e82b4aab7599d0157f11c798c9f9c528a72b9e0b339a"}, + {file = "pydantic-1.9.1-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:ae72f8098acb368d877b210ebe02ba12585e77bd0db78ac04a1ee9b9f5dd2166"}, + {file = "pydantic-1.9.1-cp36-cp36m-win_amd64.whl", hash 
= "sha256:79b485767c13788ee314669008d01f9ef3bc05db9ea3298f6a50d3ef596a154b"}, + {file = "pydantic-1.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:494f7c8537f0c02b740c229af4cb47c0d39840b829ecdcfc93d91dcbb0779892"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0f047e11febe5c3198ed346b507e1d010330d56ad615a7e0a89fae604065a0e"}, + {file = "pydantic-1.9.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:969dd06110cb780da01336b281f53e2e7eb3a482831df441fb65dd30403f4608"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:177071dfc0df6248fd22b43036f936cfe2508077a72af0933d0c1fa269b18537"}, + {file = "pydantic-1.9.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9bcf8b6e011be08fb729d110f3e22e654a50f8a826b0575c7196616780683380"}, + {file = "pydantic-1.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:a955260d47f03df08acf45689bd163ed9df82c0e0124beb4251b1290fa7ae728"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9ce157d979f742a915b75f792dbd6aa63b8eccaf46a1005ba03aa8a986bde34a"}, + {file = "pydantic-1.9.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0bf07cab5b279859c253d26a9194a8906e6f4a210063b84b433cf90a569de0c1"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d93d4e95eacd313d2c765ebe40d49ca9dd2ed90e5b37d0d421c597af830c195"}, + {file = "pydantic-1.9.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1542636a39c4892c4f4fa6270696902acb186a9aaeac6f6cf92ce6ae2e88564b"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a9af62e9b5b9bc67b2a195ebc2c2662fdf498a822d62f902bf27cccb52dbbf49"}, + {file = "pydantic-1.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fe4670cb32ea98ffbf5a1262f14c3e102cccd92b1869df3bb09538158ba90fe6"}, + {file = "pydantic-1.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:9f659a5ee95c8baa2436d392267988fd0f43eb774e5eb8739252e5a7e9cf07e0"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b83ba3825bc91dfa989d4eed76865e71aea3a6ca1388b59fc801ee04c4d8d0d6"}, + {file = "pydantic-1.9.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1dd8fecbad028cd89d04a46688d2fcc14423e8a196d5b0a5c65105664901f810"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02eefd7087268b711a3ff4db528e9916ac9aa18616da7bca69c1871d0b7a091f"}, + {file = "pydantic-1.9.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7eb57ba90929bac0b6cc2af2373893d80ac559adda6933e562dcfb375029acee"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4ce9ae9e91f46c344bec3b03d6ee9612802682c1551aaf627ad24045ce090761"}, + {file = "pydantic-1.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:72ccb318bf0c9ab97fc04c10c37683d9eea952ed526707fabf9ac5ae59b701fd"}, + {file = "pydantic-1.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:61b6760b08b7c395975d893e0b814a11cf011ebb24f7d869e7118f5a339a82e1"}, + {file = "pydantic-1.9.1-py3-none-any.whl", hash = "sha256:4988c0f13c42bfa9ddd2fe2f569c9d54646ce84adc5de84228cfe83396f3bd58"}, + {file = "pydantic-1.9.1.tar.gz", hash = "sha256:1ed987c3ff29fff7fd8c3ea3a3ea877ad310aae2ef9889a119e22d3f2db0691a"}, +] pyparsing = [ {file = "pyparsing-3.0.9-py3-none-any.whl", hash = 
"sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, @@ -770,6 +830,10 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +typing-extensions = [ + {file = "typing_extensions-4.2.0-py3-none-any.whl", hash = "sha256:6657594ee297170d19f67d55c05852a874e7eb634f4f753dbd667855e07c1708"}, + {file = "typing_extensions-4.2.0.tar.gz", hash = "sha256:f1c24655a0da0d1b67f07e17a5e6b2a105894e6824b92096378bb3668ef02376"}, +] virtualenv = [ {file = "virtualenv-20.14.1-py2.py3-none-any.whl", hash = "sha256:e617f16e25b42eb4f6e74096b9c9e37713cf10bf30168fb4a739f3fa8f898a3a"}, {file = "virtualenv-20.14.1.tar.gz", hash = "sha256:ef589a79795589aada0c1c5b319486797c03b67ac3984c48c669c0e4f50df3a5"}, diff --git a/python/pyproject.toml b/python/pyproject.toml index 5d7b75f868a9..07e5ff12edc8 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -40,6 +40,8 @@ packages = [ [tool.poetry.dependencies] python = "^3.8" mmh3 = "^3.0.0" +pydantic = "^1.9.1" + pyarrow = { version = "^8.0.0", optional = true } zstandard = { version = "^0.17.0", optional = true } diff --git a/python/src/iceberg/exceptions.py b/python/src/iceberg/exceptions.py index 3dd68e55d443..b12e836e46ce 100644 --- a/python/src/iceberg/exceptions.py +++ b/python/src/iceberg/exceptions.py @@ -30,3 +30,7 @@ class NamespaceNotEmptyError(Exception): class AlreadyExistsError(Exception): """Raised when a table or name-space being created already exists in the catalog""" + + +class ValidationError(Exception): + ... diff --git a/python/src/iceberg/schema.py b/python/src/iceberg/schema.py index 99ef37c4b9df..b57da3679052 100644 --- a/python/src/iceberg/schema.py +++ b/python/src/iceberg/schema.py @@ -16,8 +16,6 @@ # under the License. # pylint: disable=W0511 -from __future__ import annotations - from abc import ABC, abstractmethod from dataclasses import dataclass from functools import cached_property, singledispatch @@ -25,9 +23,15 @@ Any, Dict, Generic, + List, + Optional, + Tuple, TypeVar, + Union, ) +from pydantic import Field, PrivateAttr + from iceberg.files import StructProtocol from iceberg.types import ( IcebergType, @@ -37,11 +41,12 @@ PrimitiveType, StructType, ) +from iceberg.utils.iceberg_base_model import IcebergBaseModel T = TypeVar("T") -class Schema: +class Schema(IcebergBaseModel): """A table Schema Example: @@ -49,11 +54,17 @@ class Schema: >>> from iceberg import types """ - def __init__(self, *columns: NestedField, schema_id: int, identifier_field_ids: list[int] | None = None): - self._struct = StructType(*columns) - self._schema_id = schema_id - self._identifier_field_ids = identifier_field_ids or [] - self._name_to_id: dict[str, int] = index_by_name(self) + fields: Tuple[NestedField, ...] 
= Field(default_factory=tuple) + schema_id: int = Field(alias="schema-id") + identifier_field_ids: List[int] = Field(alias="identifier-field-ids", default_factory=list) + + _name_to_id: Dict[str, int] = PrivateAttr() + + def __init__(self, *fields: NestedField, **data): + if fields: + data["fields"] = fields + super().__init__(**data) + self._name_to_id = index_by_name(self) def __str__(self): return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}" @@ -79,21 +90,12 @@ def __eq__(self, other) -> bool: return identifier_field_ids_is_equal and schema_is_equal @property - def columns(self) -> tuple[NestedField, ...]: - """A list of the top-level fields in the underlying struct""" - return self._struct.fields - - @property - def schema_id(self) -> int: - """The ID of this Schema""" - return self._schema_id - - @property - def identifier_field_ids(self) -> list[int]: - return self._identifier_field_ids + def columns(self) -> Tuple[NestedField, ...]: + """A tuple of the top-level fields""" + return self.fields @cached_property - def _lazy_id_to_field(self) -> dict[int, NestedField]: + def _lazy_id_to_field(self) -> Dict[int, NestedField]: """Returns an index of field ID to NestedField instance This is calculated once when called for the first time. Subsequent calls to this method will use a cached index. @@ -101,7 +103,7 @@ def _lazy_id_to_field(self) -> dict[int, NestedField]: return index_by_id(self) @cached_property - def _lazy_name_to_id_lower(self) -> dict[str, int]: + def _lazy_name_to_id_lower(self) -> Dict[str, int]: """Returns an index of lower-case field names to field IDs This is calculated once when called for the first time. Subsequent calls to this method will use a cached index. @@ -109,7 +111,7 @@ def _lazy_name_to_id_lower(self) -> dict[str, int]: return {name.lower(): field_id for name, field_id in self._name_to_id.items()} @cached_property - def _lazy_id_to_name(self) -> dict[int, str]: + def _lazy_id_to_name(self) -> Dict[int, str]: """Returns an index of field ID to full name This is calculated once when called for the first time. Subsequent calls to this method will use a cached index. @@ -117,7 +119,7 @@ def _lazy_id_to_name(self) -> dict[int, str]: return index_name_by_id(self) @cached_property - def _lazy_id_to_accessor(self) -> dict[int, Accessor]: + def _lazy_id_to_accessor(self) -> Dict[int, "Accessor"]: """Returns an index of field ID to accessor This is calculated once when called for the first time. Subsequent calls to this method will use a cached index. 
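For context, Schema is now an ordinary pydantic model: the custom __init__ routes positional NestedField arguments into the fields tuple, and serialization honours the spec's dashed aliases. A minimal sketch of the intended usage (invented column names, not part of the patch):

# Sketch: constructing and round-tripping the pydantic-based Schema.
from iceberg.schema import Schema
from iceberg.types import LongType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=True),
    NestedField(field_id=2, name="data", field_type=StringType(), required=False),
    schema_id=1,
)
# json() serializes by alias, so the output uses the spec's dashed keys,
# e.g. "schema-id" and "identifier-field-ids" rather than the field names.
serialized = schema.json()
# parse_raw() reads the same JSON back through the alias-aware validators.
same_schema = Schema.parse_raw(serialized)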
@@ -125,10 +127,10 @@ def _lazy_id_to_accessor(self) -> dict[int, Accessor]: return build_position_accessors(self) def as_struct(self) -> StructType: - """Returns the underlying struct""" - return self._struct + """Returns the schema as a struct""" + return StructType(*self.fields) - def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField | None: + def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> Optional[NestedField]: """Find a field using a field name or field ID Args: @@ -146,7 +148,7 @@ def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> Nest field_id = self._lazy_name_to_id_lower.get(name_or_id.lower()) return self._lazy_id_to_field.get(field_id) # type: ignore - def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType: + def find_type(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> IcebergType: """Find a field type using a field name or field ID Args: @@ -161,7 +163,7 @@ def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> Icebe raise ValueError(f"Could not find field with name or id {name_or_id}, case_sensitive={case_sensitive}") return field.field_type - def find_column_name(self, column_id: int) -> str | None: + def find_column_name(self, column_id: int) -> Optional[str]: """Find a column name given a column ID Args: @@ -172,7 +174,7 @@ def find_column_name(self, column_id: int) -> str | None: """ return self._lazy_id_to_name.get(column_id) - def accessor_for_field(self, field_id: int) -> Accessor | None: + def accessor_for_field(self, field_id: int) -> Optional["Accessor"]: """Find a schema position accessor given a field ID Args: @@ -183,7 +185,7 @@ def accessor_for_field(self, field_id: int) -> Accessor | None: """ return self._lazy_id_to_accessor.get(field_id) - def select(self, names: list[str], case_sensitive: bool = True) -> Schema: + def select(self, names: List[str], case_sensitive: bool = True) -> "Schema": """Return a new schema instance pruned to a subset of columns Args: @@ -198,12 +200,12 @@ def select(self, names: list[str], case_sensitive: bool = True) -> Schema: return self._case_insensitive_select(schema=self, names=names) @classmethod - def _case_sensitive_select(cls, schema: Schema, names: list[str]): + def _case_sensitive_select(cls, schema: "Schema", names: List[str]): # TODO: Add a PruneColumns schema visitor and use it here raise NotImplementedError() @classmethod - def _case_insensitive_select(cls, schema: Schema, names: list[str]): + def _case_insensitive_select(cls, schema: "Schema", names: List[str]): # TODO: Add a PruneColumns schema visitor and use it here raise NotImplementedError() @@ -244,7 +246,7 @@ def schema(self, schema: Schema, struct_result: T) -> T: """Visit a Schema""" @abstractmethod - def struct(self, struct: StructType, field_results: list[T]) -> T: + def struct(self, struct: StructType, field_results: List[T]) -> T: """Visit a StructType""" @abstractmethod @@ -269,7 +271,7 @@ class Accessor: """An accessor for a specific position in a container that implements the StructProtocol""" position: int - inner: Accessor | None = None + inner: Optional["Accessor"] = None def __str__(self): return f"Accessor(position={self.position},inner={self.inner})" @@ -336,9 +338,9 @@ def _(obj: StructType, visitor: SchemaVisitor[T]) -> T: def _(obj: ListType, visitor: SchemaVisitor[T]) -> T: """Visit a ListType with a concrete SchemaVisitor""" - visitor.before_list_element(obj.element) - result = 
visit(obj.element.field_type, visitor) - visitor.after_list_element(obj.element) + visitor.before_list_element(obj.element_field) + result = visit(obj.element_type, visitor) + visitor.after_list_element(obj.element_field) return visitor.list(obj, result) @@ -346,13 +348,13 @@ def _(obj: ListType, visitor: SchemaVisitor[T]) -> T: @visit.register(MapType) def _(obj: MapType, visitor: SchemaVisitor[T]) -> T: """Visit a MapType with a concrete SchemaVisitor""" - visitor.before_map_key(obj.key) - key_result = visit(obj.key.field_type, visitor) - visitor.after_map_key(obj.key) + visitor.before_map_key(obj.key_field) + key_result = visit(obj.key_type, visitor) + visitor.after_map_key(obj.key_field) - visitor.before_map_value(obj.value) - value_result = visit(obj.value.field_type, visitor) - visitor.after_list_element(obj.value) + visitor.before_map_value(obj.value_field) + value_result = visit(obj.value_type, visitor) + visitor.after_map_value(obj.value_field) return visitor.map(obj, key_result, value_result) @@ -367,35 +369,35 @@ class _IndexById(SchemaVisitor[Dict[int, NestedField]]): """A schema visitor for generating a field ID to NestedField index""" def __init__(self) -> None: - self._index: dict[int, NestedField] = {} + self._index: Dict[int, NestedField] = {} - def schema(self, schema: Schema, struct_result) -> dict[int, NestedField]: + def schema(self, schema: Schema, struct_result) -> Dict[int, NestedField]: return self._index - def struct(self, struct: StructType, field_results) -> dict[int, NestedField]: + def struct(self, struct: StructType, field_results) -> Dict[int, NestedField]: return self._index - def field(self, field: NestedField, field_result) -> dict[int, NestedField]: + def field(self, field: NestedField, field_result) -> Dict[int, NestedField]: """Add the field ID to the index""" self._index[field.field_id] = field return self._index - def list(self, list_type: ListType, element_result) -> dict[int, NestedField]: + def list(self, list_type: ListType, element_result) -> Dict[int, NestedField]: """Add the list element ID to the index""" - self._index[list_type.element.field_id] = list_type.element + self._index[list_type.element_field.field_id] = list_type.element_field return self._index - def map(self, map_type: MapType, key_result, value_result) -> dict[int, NestedField]: + def map(self, map_type: MapType, key_result, value_result) -> Dict[int, NestedField]: """Add the key ID and value ID as individual items in the index""" - self._index[map_type.key.field_id] = map_type.key - self._index[map_type.value.field_id] = map_type.value + self._index[map_type.key_field.field_id] = map_type.key_field + self._index[map_type.value_field.field_id] = map_type.value_field return self._index - def primitive(self, primitive) -> dict[int, NestedField]: + def primitive(self, primitive) -> Dict[int, NestedField]: return self._index -def index_by_id(schema_or_type) -> dict[int, NestedField]: +def index_by_id(schema_or_type) -> Dict[int, NestedField]: """Generate an index of field IDs to NestedField instances Args: @@ -411,11 +413,11 @@ class _IndexByName(SchemaVisitor[Dict[str, int]]): """A schema visitor for generating a field name to field ID index""" def __init__(self) -> None: - self._index: dict[str, int] = {} - self._short_name_to_id: dict[str, int] = {} - self._combined_index: dict[str, int] = {} - self._field_names: list[str] = [] - self._short_field_names: list[str] = [] + self._index: Dict[str, int] = {} + self._short_name_to_id: Dict[str, int] = {} + self._combined_index: 
Dict[str, int] = {} + self._field_names: List[str] = [] + self._short_field_names: List[str] = [] def before_list_element(self, element: NestedField) -> None: """Short field names omit element when the element is a StructType""" @@ -438,26 +440,26 @@ def after_field(self, field: NestedField) -> None: self._field_names.pop() self._short_field_names.pop() - def schema(self, schema: Schema, struct_result: dict[str, int]) -> dict[str, int]: + def schema(self, schema: Schema, struct_result: Dict[str, int]) -> Dict[str, int]: return self._index - def struct(self, struct: StructType, field_results: list[dict[str, int]]) -> dict[str, int]: + def struct(self, struct: StructType, field_results: List[Dict[str, int]]) -> Dict[str, int]: return self._index - def field(self, field: NestedField, field_result: dict[str, int]) -> dict[str, int]: + def field(self, field: NestedField, field_result: Dict[str, int]) -> Dict[str, int]: """Add the field name to the index""" self._add_field(field.name, field.field_id) return self._index - def list(self, list_type: ListType, element_result: dict[str, int]) -> dict[str, int]: + def list(self, list_type: ListType, element_result: Dict[str, int]) -> Dict[str, int]: """Add the list element name to the index""" - self._add_field(list_type.element.name, list_type.element.field_id) + self._add_field(list_type.element_field.name, list_type.element_field.field_id) return self._index - def map(self, map_type: MapType, key_result: dict[str, int], value_result: dict[str, int]) -> dict[str, int]: + def map(self, map_type: MapType, key_result: Dict[str, int], value_result: Dict[str, int]) -> Dict[str, int]: """Add the key name and value name as individual items in the index""" - self._add_field(map_type.key.name, map_type.key.field_id) - self._add_field(map_type.value.name, map_type.value.field_id) + self._add_field(map_type.key_field.name, map_type.key_field.field_id) + self._add_field(map_type.value_field.name, map_type.value_field.field_id) return self._index def _add_field(self, name: str, field_id: int): @@ -483,10 +485,10 @@ def _add_field(self, name: str, field_id: int): short_name = ".".join([".".join(self._short_field_names), name]) self._short_name_to_id[short_name] = field_id - def primitive(self, primitive) -> dict[str, int]: + def primitive(self, primitive) -> Dict[str, int]: return self._index - def by_name(self) -> dict[str, int]: + def by_name(self) -> Dict[str, int]: """Returns an index of combined full and short names Note: Only short names that do not conflict with full names are included. 
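To make the short-name bookkeeping concrete, here is a sketch against an invented nested schema (not part of the patch):

# Sketch: index_by_name() records full names and, where unambiguous, short
# names that omit "element" for list elements that are structs.
from iceberg.schema import Schema, index_by_name
from iceberg.types import ListType, NestedField, StringType, StructType

schema = Schema(
    NestedField(
        field_id=1,
        name="points",
        field_type=ListType(
            element_id=2,
            element=StructType(NestedField(field_id=3, name="x", field_type=StringType())),
            element_required=True,
        ),
    ),
    schema_id=1,
)
index = index_by_name(schema)
assert index["points.element.x"] == 3  # full name, always present
assert index["points.x"] == 3          # short name, since the element is a struct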
@@ -495,13 +497,13 @@ def by_name(self) -> dict[str, int]: combined_index.update(self._index) return combined_index - def by_id(self) -> dict[int, str]: + def by_id(self) -> Dict[int, str]: """Returns an index of ID to full names""" id_to_full_name = {value: key for key, value in self._index.items()} return id_to_full_name -def index_by_name(schema_or_type: Schema | IcebergType) -> dict[str, int]: +def index_by_name(schema_or_type: Union[Schema, IcebergType]) -> Dict[str, int]: """Generate an index of field names to field IDs Args: @@ -515,7 +517,7 @@ def index_by_name(schema_or_type: Schema | IcebergType) -> dict[str, int]: return indexer.by_name() -def index_name_by_id(schema_or_type: Schema | IcebergType) -> dict[int, str]: +def index_name_by_id(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, str]: """Generate an index of field IDs full field names Args: @@ -565,13 +567,13 @@ class _BuildPositionAccessors(SchemaVisitor[Dict[Position, Accessor]]): """ @staticmethod - def _wrap_leaves(result: dict[Position, Accessor], position: Position = 0) -> dict[Position, Accessor]: + def _wrap_leaves(result: Dict[Position, Accessor], position: Position = 0) -> Dict[Position, Accessor]: return {field_id: Accessor(position, inner=inner) for field_id, inner in result.items()} - def schema(self, schema: Schema, struct_result: dict[Position, Accessor]) -> dict[Position, Accessor]: + def schema(self, schema: Schema, struct_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]: return struct_result - def struct(self, struct: StructType, field_results: list[dict[Position, Accessor]]) -> dict[Position, Accessor]: + def struct(self, struct: StructType, field_results: List[Dict[Position, Accessor]]) -> Dict[Position, Accessor]: result = {} for position, field in enumerate(struct.fields): @@ -583,22 +585,22 @@ def struct(self, struct: StructType, field_results: list[dict[Position, Accessor return result - def field(self, field: NestedField, field_result: dict[Position, Accessor]) -> dict[Position, Accessor]: + def field(self, field: NestedField, field_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]: return field_result - def list(self, list_type: ListType, element_result: dict[Position, Accessor]) -> dict[Position, Accessor]: + def list(self, list_type: ListType, element_result: Dict[Position, Accessor]) -> Dict[Position, Accessor]: return {} def map( - self, map_type: MapType, key_result: dict[Position, Accessor], value_result: dict[Position, Accessor] - ) -> dict[Position, Accessor]: + self, map_type: MapType, key_result: Dict[Position, Accessor], value_result: Dict[Position, Accessor] + ) -> Dict[Position, Accessor]: return {} - def primitive(self, primitive: PrimitiveType) -> dict[Position, Accessor]: + def primitive(self, primitive: PrimitiveType) -> Dict[Position, Accessor]: return {} -def build_position_accessors(schema_or_type: Schema | IcebergType) -> dict[int, Accessor]: +def build_position_accessors(schema_or_type: Union[Schema, IcebergType]) -> Dict[int, Accessor]: """Generate an index of field IDs to schema position accessors Args: diff --git a/python/src/iceberg/serializers.py b/python/src/iceberg/serializers.py new file mode 100644 index 000000000000..98e279f6240a --- /dev/null +++ b/python/src/iceberg/serializers.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import codecs +import json +from typing import Union + +from iceberg.io.base import InputFile, InputStream, OutputFile +from iceberg.table.metadata import TableMetadata, TableMetadataV1, TableMetadataV2 + + +class FromByteStream: + """A collection of methods that deserialize byte streams into Iceberg objects""" + + @staticmethod + def table_metadata(byte_stream: InputStream, encoding: str = "utf-8") -> TableMetadata: + """Instantiate a TableMetadata object from a byte stream + + Args: + byte_stream: A file-like byte stream object + encoding (default "utf-8"): The byte decoder to use for the reader + """ + reader = codecs.getreader(encoding) + metadata = json.load(reader(byte_stream)) # type: ignore + return TableMetadata.parse_obj(metadata) # type: ignore + + +class FromInputFile: + """A collection of methods that deserialize InputFiles into Iceberg objects""" + + @staticmethod + def table_metadata(input_file: InputFile, encoding: str = "utf-8") -> TableMetadata: + """Create a TableMetadata instance from an input file + + Args: + input_file (InputFile): A custom implementation of the iceberg.io.file.InputFile abstract base class + encoding (str): Encoding to use when loading bytestream + + Returns: + TableMetadata: A table metadata instance + + """ + return FromByteStream.table_metadata(byte_stream=input_file.open(), encoding=encoding) + + +class ToOutputFile: + """A collection of methods that serialize Iceberg objects into files given an OutputFile instance""" + + @staticmethod + def table_metadata( + metadata: Union[TableMetadataV1, TableMetadataV2], output_file: OutputFile, overwrite: bool = False + ) -> None: + """Write a TableMetadata instance to an output file + + Args: + output_file (OutputFile): A custom implementation of the iceberg.io.file.OutputFile abstract base class + overwrite (bool): Whether to overwrite the file if it already exists. Defaults to `False`. + """ + f = output_file.create(overwrite=overwrite) + f.write(metadata.json().encode("utf-8")) + f.close() diff --git a/python/src/iceberg/table/metadata.py b/python/src/iceberg/table/metadata.py new file mode 100644 index 000000000000..55c43c1cf6cd --- /dev/null +++ b/python/src/iceberg/table/metadata.py @@ -0,0 +1,356 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from copy import copy +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Union, +) +from uuid import UUID, uuid4 + +from pydantic import Field, root_validator + +from iceberg.exceptions import ValidationError +from iceberg.schema import Schema +from iceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType +from iceberg.utils.iceberg_base_model import IcebergBaseModel + +_INITIAL_SEQUENCE_NUMBER = 0 +INITIAL_SPEC_ID = 0 +DEFAULT_SCHEMA_ID = 0 +DEFAULT_SORT_ORDER_UNSORTED = 0 + + +def check_schemas(values: Dict[str, Any]) -> Dict[str, Any]: + """Validator to check if the current-schema-id is actually present in schemas""" + current_schema_id = values["current_schema_id"] + + for schema in values["schemas"]: + if schema.schema_id == current_schema_id: + return values + + raise ValidationError(f"current-schema-id {current_schema_id} can't be found in the schemas") + + +def check_partition_specs(values: Dict[str, Any]) -> Dict[str, Any]: + """Validator to check if the default-spec-id is present in partition-specs""" + default_spec_id = values["default_spec_id"] + + for spec in values["partition_specs"]: + if spec["spec-id"] == default_spec_id: + return values + + raise ValidationError(f"default-spec-id {default_spec_id} can't be found") + + +def check_sort_orders(values: Dict[str, Any]) -> Dict[str, Any]: + """Validator to check if the default_sort_order_id is present in sort-orders""" + default_sort_order_id = values["default_sort_order_id"] + + if default_sort_order_id != DEFAULT_SORT_ORDER_UNSORTED: + for sort in values["sort_orders"]: + if sort["order-id"] == default_sort_order_id: + return values + + raise ValidationError(f"default-sort-order-id {default_sort_order_id} can't be found") + return values + + +class TableMetadataCommonFields(IcebergBaseModel): + """Metadata for an Iceberg table as specified in the Apache Iceberg + spec (https://iceberg.apache.org/spec/#iceberg-table-spec)""" + + @root_validator(pre=True) + def cleanup_snapshot_id(cls, data: Dict[str, Any]): + if data.get("current-snapshot-id") == -1: + # We treat -1 and None the same, by cleaning this up + # in a pre-validator, we can simplify the logic later on + data["current-snapshot-id"] = None + return data + + @root_validator(skip_on_failure=True) + def construct_refs(cls, data: Dict[str, Any]): + # This is going to be much nicer as soon as refs is an actual pydantic object + if current_snapshot_id := data.get("current_snapshot_id"): + data["refs"] = {MAIN_BRANCH: SnapshotRef(snapshot_id=current_snapshot_id, snapshot_ref_type=SnapshotRefType.BRANCH)} + return data + + location: str = Field() + """The table’s base location. This is used by writers to determine where + to store data files, manifest files, and table metadata files.""" + + table_uuid: Optional[UUID] = Field(alias="table-uuid") + """A UUID that identifies the table, generated when the table is created. + Implementations must throw an exception if a table’s UUID does not match + the expected UUID after refreshing metadata.""" + + last_updated_ms: int = Field(alias="last-updated-ms") + """Timestamp in milliseconds from the unix epoch when the table + was last updated. Each table metadata file should update this + field just before writing.""" + + last_column_id: int = Field(alias="last-column-id") + """An integer; the highest assigned column ID for the table. 
+ This is used to ensure fields are always assigned an unused ID + when evolving schemas.""" + + schemas: List[Schema] = Field(default_factory=list) + """A list of schemas, stored as objects with schema-id.""" + + current_schema_id: int = Field(alias="current-schema-id", default=DEFAULT_SCHEMA_ID) + """ID of the table’s current schema.""" + + partition_specs: list = Field(alias="partition-specs", default_factory=list) + """A list of partition specs, stored as full partition spec objects.""" + + default_spec_id: int = Field(alias="default-spec-id", default=INITIAL_SPEC_ID) + """ID of the “current” spec that writers should use by default.""" + + last_partition_id: int = Field(alias="last-partition-id") + """An integer; the highest assigned partition field ID across all + partition specs for the table. This is used to ensure partition fields + are always assigned an unused ID when evolving specs.""" + + properties: Dict[str, str] = Field(default_factory=dict) + """ A string to string map of table properties. This is used to + control settings that affect reading and writing and is not intended + to be used for arbitrary metadata. For example, commit.retry.num-retries + is used to control the number of commit retries.""" + + current_snapshot_id: Optional[int] = Field(alias="current-snapshot-id") + """ID of the current table snapshot.""" + + snapshots: list = Field(default_factory=list) + """A list of valid snapshots. Valid snapshots are snapshots for which + all data files exist in the file system. A data file must not be + deleted from the file system until the last snapshot in which it was + listed is garbage collected.""" + + snapshot_log: List[Dict[str, Any]] = Field(alias="snapshot-log", default_factory=list) + """A list (optional) of timestamp and snapshot ID pairs that encodes + changes to the current snapshot for the table. Each time the + current-snapshot-id is changed, a new entry should be added with the + last-updated-ms and the new current-snapshot-id. When snapshots are + expired from the list of valid snapshots, all entries before a snapshot + that has expired should be removed.""" + + metadata_log: List[Dict[str, Any]] = Field(alias="metadata-log", default_factory=list) + """A list (optional) of timestamp and metadata file location pairs that + encodes changes to the previous metadata files for the table. Each time + a new metadata file is created, a new entry of the previous metadata + file location should be added to the list. Tables can be configured to + remove oldest metadata log entries and keep a fixed-size log of the most + recent entries after a commit.""" + + sort_orders: List[Dict[str, Any]] = Field(alias="sort-orders", default_factory=list) + """A list of sort orders, stored as full sort order objects.""" + + default_sort_order_id: int = Field(alias="default-sort-order-id", default=DEFAULT_SORT_ORDER_UNSORTED) + """Default sort order id of the table. Note that this could be used by + writers, but is not used when reading because reads use the specs stored + in manifest files.""" + + refs: Dict[str, SnapshotRef] = Field(default_factory=dict) + """A map of snapshot references. + The map keys are the unique snapshot reference names in the table, + and the map values are snapshot reference objects. 
+ There is always a main branch reference pointing to the + current-snapshot-id even if the refs map is null.""" + + +class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): + """Represents version 1 of the Table Metadata + + More information about the specification: + https://iceberg.apache.org/spec/#version-1-analytic-data-tables + """ + + # When we read a V1 format-version, we'll make sure to populate the fields + # for V2 as well. This makes it easier downstream because we can just + # assume that everything is a TableMetadataV2. + # When writing, we should stick to the same version that it was, + # because bumping the version should be an explicit operation that is up + # to the owner of the table. + + @root_validator(pre=True) + def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]: + """Sets default values to be compatible with the format v2 + + Set some sensible defaults for V1, so we comply with the schema. + This is in pre=True, meaning that this will be done before validation. + We don't want to make the fields optional, since they are required for V2 + + Args: + data: The raw arguments when initializing a V1 TableMetadata + + Returns: + The TableMetadata with the defaults applied + """ + if "schema-id" not in data["schema"]: + data["schema"]["schema-id"] = DEFAULT_SCHEMA_ID + if "last-partition-id" not in data: + data["last-partition-id"] = max(spec["field-id"] for spec in data["partition-spec"]) + if "table-uuid" not in data: + data["table-uuid"] = uuid4() + return data + + @root_validator(skip_on_failure=True) + def construct_schemas(cls, data: Dict[str, Any]) -> Dict[str, Any]: + """Converts the schema into schemas + + For V1 schemas is optional, and if they aren't set, we'll set them + in this validator. This way we can always use the schemas when reading + table metadata, and we don't have to worry if it is a v1 or v2 format. + + Args: + data: The raw data after validation, meaning that the aliases are applied + + Returns: + The TableMetadata with the schemas set, if not provided + """ + if not data.get("schemas"): + schema = data["schema_"] + data["schemas"] = [schema] + else: + check_schemas(data) + return data + + @root_validator(skip_on_failure=True) + def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]: + """Converts the partition_spec into partition_specs + + For V1 partition_specs is optional, and if they aren't set, we'll set them + in this validator. This way we can always use the partition_specs when reading + table metadata, and we don't have to worry if it is a v1 or v2 format. + + Args: + data: The raw data after validation, meaning that the aliases are applied + + Returns: + The TableMetadata with the partition_specs set, if not provided + """ + # This is going to be much nicer as soon as partition-spec is also migrated to pydantic + if not data.get("partition_specs"): + fields = data["partition_spec"] + data["partition_specs"] = [{"spec-id": INITIAL_SPEC_ID, "fields": fields}] + else: + check_partition_specs(data) + return data + + @root_validator(skip_on_failure=True) + def set_sort_orders(cls, data: Dict[str, Any]): + """Sets the sort_orders if not provided + + For V1 sort_orders is optional, and if they aren't set, we'll set them + in this validator. 
+ + Args: + data: The raw data after validation, meaning that the aliases are applied + + Returns: + The TableMetadata with the sort_orders set, if not provided + """ + # This is going to be much nicer as soon as sort-order is an actual pydantic object + # Probably we'll just create an UNSORTED_ORDER constant then + if not data.get("sort_orders"): + data["sort_orders"] = [{"order_id": 0, "fields": []}] + else: + check_sort_orders(data) + return data + + def to_v2(self) -> "TableMetadataV2": + metadata = copy(self.dict()) + metadata["format_version"] = 2 + return TableMetadataV2(**metadata) + + format_version: Literal[1] = Field(alias="format-version") + """An integer version number for the format. Currently, this can be 1 or 2 + based on the spec. Implementations must throw an exception if a table’s + version is higher than the supported version.""" + + schema_: Schema = Field(alias="schema") + """The table’s current schema. (Deprecated: use schemas and + current-schema-id instead)""" + + partition_spec: List[Dict[str, Any]] = Field(alias="partition-spec") + """The table’s current partition spec, stored as only fields. + Note that this is used by writers to partition data, but is + not used when reading because reads use the specs stored in + manifest files. (Deprecated: use partition-specs and default-spec-id + instead)""" + + +class TableMetadataV2(TableMetadataCommonFields, IcebergBaseModel): + """Represents version 2 of the Table Metadata + + This extends Version 1 with row-level deletes, and adds some additional + information to the schema, such as all the historical schemas, partition-specs, + sort-orders. + + For more information: + https://iceberg.apache.org/spec/#version-2-row-level-deletes + """ + + @root_validator(skip_on_failure=True) + def check_schemas(cls, values: Dict[str, Any]): + return check_schemas(values) + + @root_validator(skip_on_failure=True) + def check_partition_specs(cls, values: Dict[str, Any]): + return check_partition_specs(values) + + @root_validator(skip_on_failure=True) + def check_sort_orders(cls, values: Dict[str, Any]): + return check_sort_orders(values) + + format_version: Literal[2] = Field(alias="format-version") + """An integer version number for the format. Currently, this can be 1 or 2 + based on the spec. Implementations must throw an exception if a table’s + version is higher than the supported version.""" + + table_uuid: UUID = Field(alias="table-uuid") + """A UUID that identifies the table, generated when the table is created. 
+ Implementations must throw an exception if a table’s UUID does not match + the expected UUID after refreshing metadata.""" + + last_sequence_number: int = Field(alias="last-sequence-number", default=_INITIAL_SEQUENCE_NUMBER) + """The table’s highest assigned sequence number, a monotonically + increasing long that tracks the order of snapshots in a table.""" + + +class TableMetadata: + """Helper class for parsing TableMetadata""" + + # Once this has been resolved, we can simplify this: https://github.com/samuelcolvin/pydantic/issues/3846 + # TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(alias="format-version", discriminator="format-version")] + + @staticmethod + def parse_obj(data: dict) -> Union[TableMetadataV1, TableMetadataV2]: + if "format-version" not in data: + raise ValidationError(f"Missing format-version in TableMetadata: {data}") + + format_version = data["format-version"] + + if format_version == 1: + return TableMetadataV1(**data) + elif format_version == 2: + return TableMetadataV2(**data) + else: + raise ValidationError(f"Unknown format version: {format_version}") diff --git a/python/src/iceberg/table/refs.py b/python/src/iceberg/table/refs.py new file mode 100644 index 000000000000..285c5cfd93b6 --- /dev/null +++ b/python/src/iceberg/table/refs.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
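Putting metadata.py and serializers.py together: reading V1 metadata upgrades it in memory so downstream code can treat everything as V2-shaped, while writing keeps the original version unless to_v2() is called explicitly. A sketch with an invented payload (not part of the patch):

# Sketch: parsing V1 table metadata and explicitly bumping it to V2.
from iceberg.table.metadata import TableMetadata

metadata = TableMetadata.parse_obj(
    {
        "format-version": 1,
        "location": "s3://bucket/warehouse/tbl",
        "last-updated-ms": 1650000000000,
        "last-column-id": 1,
        "schema": {
            "type": "struct",
            "fields": [{"id": 1, "name": "id", "type": "long", "required": True}],
        },
        "partition-spec": [
            {"name": "id_bucket", "transform": "bucket[16]", "source-id": 1, "field-id": 1000}
        ],
    }
)
# The pre-validators filled in table-uuid, schemas, partition-specs, and so on.
v2_metadata = metadata.to_v2()  # bumping the format version stays an explicit step
# Writing would go through ToOutputFile.table_metadata(v2_metadata, output_file),
# where output_file is any concrete OutputFile implementation (not shown here).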
+from enum import Enum +from typing import Optional + +from pydantic import Field + +from iceberg.utils.iceberg_base_model import IcebergBaseModel + +MAIN_BRANCH = "main" + + +class SnapshotRefType(str, Enum): + BRANCH = "branch" + TAG = "tag" + + +class SnapshotRef(IcebergBaseModel): + snapshot_id: int = Field(alias="snapshot-id") + snapshot_ref_type: SnapshotRefType = Field(alias="type") + min_snapshots_to_keep: Optional[int] = Field(alias="min-snapshots-to-keep", default=None) + max_snapshot_age_ms: Optional[int] = Field(alias="max-snapshot-age-ms", default=None) + max_ref_age_ms: Optional[int] = Field(alias="max-ref-age-ms", default=None) diff --git a/python/src/iceberg/types.py b/python/src/iceberg/types.py index d3ae67237380..ecfb65b10a96 100644 --- a/python/src/iceberg/types.py +++ b/python/src/iceberg/types.py @@ -29,16 +29,25 @@ Notes: - https://iceberg.apache.org/#spec/#primitive-types """ -from abc import ABC -from dataclasses import dataclass, field -from functools import cached_property -from typing import ClassVar, Optional, Tuple - +import re +from typing import ( + ClassVar, + Dict, + Literal, + Optional, + Tuple, +) + +from pydantic import Field, PrivateAttr + +from iceberg.utils.iceberg_base_model import IcebergBaseModel from iceberg.utils.singleton import Singleton +DECIMAL_REGEX = re.compile(r"decimal\((\d+),\s*(\d+)\)") +FIXED_REGEX = re.compile(r"fixed\[(\d+)\]") + -@dataclass(frozen=True) -class IcebergType(ABC, Singleton): +class IcebergType(IcebergBaseModel, Singleton): """Base type for all Iceberg Types Example: @@ -48,50 +57,92 @@ class IcebergType(ABC, Singleton): 'IcebergType()' """ - @property - def string_type(self) -> str: - return self.__repr__() - - def __str__(self) -> str: - return self.string_type + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield cls.validate + + @classmethod + def validate(cls, v): + # When Pydantic is unable to determine the subtype + # In this case we'll help pydantic a bit by parsing the + # primitive type ourselves, or pointing it at the correct + # complex type by looking at the type field + + if isinstance(v, str): + if v.startswith("decimal"): + return DecimalType.parse(v) + elif v.startswith("fixed"): + return FixedType.parse(v) + else: + return PRIMITIVE_TYPES[v] + elif isinstance(v, dict): + if v.get("type") == "struct": + return StructType(**v) + elif v.get("type") == "list": + return ListType(**v) + elif v.get("type") == "map": + return MapType(**v) + else: + return NestedField(**v) + else: + return v @property def is_primitive(self) -> bool: return isinstance(self, PrimitiveType) -@dataclass(frozen=True, eq=True) class PrimitiveType(IcebergType): - """Base class for all Iceberg Primitive Types + """Base class for all Iceberg Primitive Types""" - Example: - >>> str(PrimitiveType()) - 'PrimitiveType()' - """ + __root__: str = Field() + + def __repr__(self) -> str: + return f"{type(self).__name__}()" + + def __str__(self) -> str: + return self.__root__ -@dataclass(frozen=True) class FixedType(PrimitiveType): """A fixed data type in Iceberg. 
- Example: >>> FixedType(8) FixedType(length=8) >>> FixedType(8) == FixedType(8) True + >>> FixedType(19) == FixedType(25) + False """ - length: int = field() + __root__: str = Field() + _length: int = PrivateAttr() + + @staticmethod + def parse(str_repr: str) -> "FixedType": + matches = FIXED_REGEX.search(str_repr) + if matches: + length = int(matches.group(1)) + return FixedType(length) + raise ValueError(f"Could not parse {str_repr} into a FixedType") + + def __init__(self, length: int): + super().__init__(__root__=f"fixed[{length}]") + self._length = length @property - def string_type(self) -> str: - return f"fixed[{self.length}]" + def length(self) -> int: + return self._length + + def __repr__(self) -> str: + return f"FixedType(length={self._length})" -@dataclass(frozen=True, eq=True) class DecimalType(PrimitiveType): """A fixed data type in Iceberg. - Example: >>> DecimalType(32, 3) DecimalType(precision=32, scale=3) @@ -99,15 +150,40 @@ class DecimalType(PrimitiveType): True """ - precision: int = field() - scale: int = field() + __root__: str = Field() + + _precision: int = PrivateAttr() + _scale: int = PrivateAttr() + + @staticmethod + def parse(str_repr: str) -> "DecimalType": + matches = DECIMAL_REGEX.search(str_repr) + if matches: + precision = int(matches.group(1)) + scale = int(matches.group(2)) + return DecimalType(precision, scale) + else: + raise ValueError(f"Could not parse {str_repr} into a DecimalType") + + def __init__(self, precision: int, scale: int): + super().__init__( + __root__=f"decimal({precision}, {scale})", + ) + self._precision = precision + self._scale = scale + + @property + def precision(self) -> int: + return self._precision @property - def string_type(self) -> str: - return f"decimal({self.precision}, {self.scale})" + def scale(self) -> int: + return self._scale + + def __repr__(self) -> str: + return f"DecimalType(precision={self._precision}, scale={self._scale})" -@dataclass(frozen=True) class NestedField(IcebergType): """Represents a field of a struct, a map key, a map value, or a list element. @@ -120,35 +196,51 @@ class NestedField(IcebergType): ... field_type=FixedType(22), ... required=False, ... )) - '1: foo: required fixed[22]' + '1: foo: optional fixed[22]' >>> str(NestedField( ... field_id=2, ... name='bar', ... field_type=LongType(), - ... required=False, + ... required=True, ... doc="Just a long" ... )) '2: bar: required long (Just a long)' """ - field_id: int = field() - name: str = field() - field_type: IcebergType = field() - required: bool = field(default=True) - doc: Optional[str] = field(default=None, repr=False) + field_id: int = Field(alias="id") + name: str = Field() + field_type: IcebergType = Field(alias="type") + required: bool = Field(default=True) + doc: Optional[str] = Field(default=None, repr=False) + + def __init__( + self, + field_id: Optional[int] = None, + name: Optional[str] = None, + field_type: Optional[IcebergType] = None, + required: bool = True, + doc: Optional[str] = None, + **data, + ): + # We need an init when we want to use positional arguments, but + also need to support the aliases. 
+ data["field_id"] = data["id"] if "id" in data else field_id + data["name"] = name + data["field_type"] = data["type"] if "type" in data else field_type + data["required"] = required + data["doc"] = doc + super().__init__(**data) + + def __str__(self) -> str: + doc = "" if not self.doc else f" ({self.doc})" + req = "required" if self.required else "optional" + return f"{self.field_id}: {self.name}: {req} {self.field_type}{doc}" @property def optional(self) -> bool: return not self.required - @property - def string_type(self) -> str: - doc = "" if not self.doc else f" ({self.doc})" - req = "optional" if self.required else "required" - return f"{self.field_id}: {self.name}: {req} {self.field_type}{doc}" - -@dataclass(frozen=True, init=False) class StructType(IcebergType): """A struct type in Iceberg @@ -160,19 +252,22 @@ class StructType(IcebergType): 'struct<1: required_field: optional string, 2: optional_field: optional int>' """ - fields: Tuple[NestedField] = field() + type: Literal["struct"] = "struct" + fields: Tuple[NestedField, ...] = Field(default_factory=tuple) - def __init__(self, *fields: NestedField, **kwargs): # pylint: disable=super-init-not-called - if not fields and "fields" in kwargs: - fields = kwargs["fields"] - object.__setattr__(self, "fields", tuple(fields)) + def __init__(self, *fields: NestedField, **data): + # In case we use positional arguments, instead of keyword args + if fields: + data["fields"] = fields + super().__init__(**data) - @cached_property - def string_type(self) -> str: + def __str__(self) -> str: return f"struct<{', '.join(map(str, self.fields))}>" + def __repr__(self) -> str: + return f"StructType(fields=({', '.join(map(repr, self.fields))},))" + -@dataclass(frozen=True) class ListType(IcebergType): """A list type in Iceberg @@ -181,29 +276,33 @@ class ListType(IcebergType): ListType(element_id=3, element_type=StringType(), element_required=True) """ - element_id: int = field() - element_type: IcebergType = field() - element_required: bool = field(default=True) - element: NestedField = field(init=False, repr=False) - - def __post_init__(self): - object.__setattr__( - self, - "element", - NestedField( - name="element", - required=self.element_required, - field_id=self.element_id, - field_type=self.element_type, - ), + class Config: + fields = {"element_field": {"exclude": True}} + + type: Literal["list"] = "list" + element_id: int = Field(alias="element-id") + element_type: IcebergType = Field(alias="element") + element_required: bool = Field(alias="element-required", default=True) + element_field: NestedField = Field(init=False, repr=False) + + def __init__( + self, element_id: Optional[int] = None, element: Optional[IcebergType] = None, element_required: bool = True, **data + ): + data["element_id"] = data["element-id"] if "element-id" in data else element_id + data["element_type"] = element or data["element_type"] + data["element_required"] = data["element-required"] if "element-required" in data else element_required + data["element_field"] = NestedField( + name="element", + required=data["element_required"], + field_id=data["element_id"], + field_type=data["element_type"], ) + super().__init__(**data) - @property - def string_type(self) -> str: + def __str__(self) -> str: return f"list<{self.element_type}>" -@dataclass(frozen=True) class MapType(IcebergType): """A map type in Iceberg @@ -212,29 +311,43 @@ class MapType(IcebergType): MapType(key_id=1, key_type=StringType(), value_id=2, value_type=IntegerType(), value_required=True) """ - key_id: 
int = field() - key_type: IcebergType = field() - value_id: int = field() - value_type: IcebergType = field() - value_required: bool = field(default=True) - key: NestedField = field(init=False, repr=False) - value: NestedField = field(init=False, repr=False) - - def __post_init__(self): - object.__setattr__(self, "key", NestedField(name="key", field_id=self.key_id, field_type=self.key_type, required=False)) - object.__setattr__( - self, - "value", - NestedField( - name="value", - field_id=self.value_id, - field_type=self.value_type, - required=self.value_required, - ), + type: Literal["map"] = "map" + key_id: int = Field(alias="key-id") + key_type: IcebergType = Field(alias="key") + value_id: int = Field(alias="value-id") + value_type: IcebergType = Field(alias="value") + value_required: bool = Field(alias="value-required", default=True) + key_field: NestedField = Field(init=False, repr=False) + value_field: NestedField = Field(init=False, repr=False) + + class Config: + fields = {"key_field": {"exclude": True}, "value_field": {"exclude": True}} + + def __init__( + self, + key_id: Optional[int] = None, + key_type: Optional[IcebergType] = None, + value_id: Optional[int] = None, + value_type: Optional[IcebergType] = None, + value_required: bool = True, + **data, + ): + data["key_id"] = key_id or data["key-id"] + data["key_type"] = key_type or data["key"] + data["value_id"] = value_id or data["value-id"] + data["value_type"] = value_type or data["value"] + data["value_required"] = value_required if value_required is not None else data["value_required"] + + data["key_field"] = NestedField(name="key", field_id=data["key_id"], field_type=data["key_type"], required=True) + data["value_field"] = NestedField( + name="value", field_id=data["value_id"], field_type=data["value_type"], required=data["value_required"] ) + super().__init__(**data) + + def __str__(self) -> str: + return f"map<{self.key_type}, {self.value_type}>" -@dataclass(frozen=True) class BooleanType(PrimitiveType): """A boolean data type in Iceberg can be represented using an instance of this class. @@ -246,12 +359,9 @@ class BooleanType(PrimitiveType): BooleanType() """ - @property - def string_type(self) -> str: - return "boolean" + __root__ = "boolean" -@dataclass(frozen=True) class IntegerType(PrimitiveType): """An Integer data type in Iceberg can be represented using an instance of this class. Integers in Iceberg are 32-bit signed and can be promoted to Longs. @@ -271,12 +381,9 @@ class IntegerType(PrimitiveType): max: ClassVar[int] = 2147483647 min: ClassVar[int] = -2147483648 - @property - def string_type(self) -> str: - return "int" + __root__ = "int" -@dataclass(frozen=True) class LongType(PrimitiveType): """A Long data type in Iceberg can be represented using an instance of this class. Longs in Iceberg are 64-bit signed integers. @@ -300,12 +407,9 @@ class LongType(PrimitiveType): max: ClassVar[int] = 9223372036854775807 min: ClassVar[int] = -9223372036854775808 - @property - def string_type(self) -> str: - return "long" + __root__ = "long" -@dataclass(frozen=True) class FloatType(PrimitiveType): """A Float data type in Iceberg can be represented using an instance of this class. Floats in Iceberg are 32-bit IEEE 754 floating points and can be promoted to Doubles. 
@@ -327,12 +431,9 @@ class FloatType(PrimitiveType): max: ClassVar[float] = 3.4028235e38 min: ClassVar[float] = -3.4028235e38 - @property - def string_type(self) -> str: - return "float" + __root__ = "float" -@dataclass(frozen=True) class DoubleType(PrimitiveType): """A Double data type in Iceberg can be represented using an instance of this class. Doubles in Iceberg are 64-bit IEEE 754 floating points. @@ -345,12 +446,9 @@ class DoubleType(PrimitiveType): DoubleType() """ - @property - def string_type(self) -> str: - return "double" + __root__ = "double" -@dataclass(frozen=True) class DateType(PrimitiveType): """A Date data type in Iceberg can be represented using an instance of this class. Dates in Iceberg are calendar dates without a timezone or time. @@ -363,12 +461,9 @@ class DateType(PrimitiveType): DateType() """ - @property - def string_type(self) -> str: - return "date" + __root__ = "date" -@dataclass(frozen=True) class TimeType(PrimitiveType): """A Time data type in Iceberg can be represented using an instance of this class. Times in Iceberg have microsecond precision and are a time of day without a date or timezone. @@ -381,12 +476,9 @@ class TimeType(PrimitiveType): TimeType() """ - @property - def string_type(self) -> str: - return "time" + __root__ = "time" -@dataclass(frozen=True) class TimestampType(PrimitiveType): """A Timestamp data type in Iceberg can be represented using an instance of this class. Timestamps in Iceberg have microsecond precision and include a date and a time of day without a timezone. @@ -399,12 +491,9 @@ class TimestampType(PrimitiveType): TimestampType() """ - @property - def string_type(self) -> str: - return "timestamp" + __root__ = "timestamp" -@dataclass(frozen=True) class TimestamptzType(PrimitiveType): """A Timestamptz data type in Iceberg can be represented using an instance of this class. Timestamptzs in Iceberg are stored as UTC and include a date and a time of day with a timezone. @@ -417,12 +506,9 @@ class TimestamptzType(PrimitiveType): TimestamptzType() """ - @property - def string_type(self) -> str: - return "timestamptz" + __root__ = "timestamptz" -@dataclass(frozen=True) class StringType(PrimitiveType): """A String data type in Iceberg can be represented using an instance of this class. Strings in Iceberg are arbitrary-length character sequences and are encoded with UTF-8. @@ -435,12 +521,9 @@ class StringType(PrimitiveType): StringType() """ - @property - def string_type(self) -> str: - return "string" + __root__ = "string" -@dataclass(frozen=True) class UUIDType(PrimitiveType): """A UUID data type in Iceberg can be represented using an instance of this class. UUIDs in Iceberg are universally unique identifiers. @@ -453,12 +536,9 @@ class UUIDType(PrimitiveType): UUIDType() """ - @property - def string_type(self) -> str: - return "uuid" + __root__ = "uuid" -@dataclass(frozen=True) class BinaryType(PrimitiveType): """A Binary data type in Iceberg can be represented using an instance of this class. Binaries in Iceberg are arbitrary-length byte arrays. 
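The `__root__` conversions in these hunks make every primitive type serialize as a bare JSON string instead of an object. A minimal sketch of the intended behavior; the TestType wrapper here simply mirrors the helper defined in test_types.py later in this diff:

    from iceberg.types import IcebergType, StringType
    from iceberg.utils.iceberg_base_model import IcebergBaseModel

    # A custom-root model so pydantic can parse a bare JSON string into a type
    class TestType(IcebergBaseModel):
        __root__: IcebergType

    assert StringType().json() == '"string"'
    assert TestType.parse_raw('"string"') == StringType()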
@@ -471,6 +551,20 @@ class BinaryType(PrimitiveType):
         BinaryType()
     """
 
-    @property
-    def string_type(self) -> str:
-        return "binary"
+    __root__ = "binary"
+
+
+PRIMITIVE_TYPES: Dict[str, PrimitiveType] = {
+    "boolean": BooleanType(),
+    "int": IntegerType(),
+    "long": LongType(),
+    "float": FloatType(),
+    "double": DoubleType(),
+    "date": DateType(),
+    "time": TimeType(),
+    "timestamp": TimestampType(),
+    "timestamptz": TimestamptzType(),
+    "string": StringType(),
+    "uuid": UUIDType(),
+    "binary": BinaryType(),
+}
diff --git a/python/src/iceberg/utils/iceberg_base_model.py b/python/src/iceberg/utils/iceberg_base_model.py
new file mode 100644
index 000000000000..e218f9f89b5f
--- /dev/null
+++ b/python/src/iceberg/utils/iceberg_base_model.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from functools import cached_property
+
+from pydantic import BaseModel
+
+
+class IcebergBaseModel(BaseModel):
+    """
+    This class extends Pydantic's BaseModel to override the serialization defaults
+    of dict() and json().
+
+    We always want to serialize with by_alias=True: dashes cannot appear in Python
+    variable names, yet they are used in field names throughout the Iceberg spec.
+
+    The same goes for exclude_none: if a field is None, we want to omit it from
+    serialization, for example, the doc attribute on the NestedField object.
+    Default non-null values are still serialized.
+
+    This is recommended by Pydantic:
+    https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+    """
+
+    class Config:
+        keep_untouched = (cached_property,)
+        allow_population_by_field_name = True
+        frozen = True
+
+    def dict(self, exclude_none: bool = True, **kwargs):
+        return super().dict(exclude_none=exclude_none, **kwargs)
+
+    def json(self, exclude_none: bool = True, by_alias: bool = True, **kwargs):
+        # A small trick to exclude private properties: pydantic serializes properties
+        # regardless of whether they start with an underscore.
+        # This looks at __dict__ and excludes every field that starts with an
+        # underscore, except for __root__.
+        exclude = set.union(
+            {field for field in self.__dict__ if field.startswith("_") and not field == "__root__"}, kwargs.get("exclude", set())
+        )
+        return super().json(exclude_none=exclude_none, exclude=exclude, by_alias=by_alias, **kwargs)
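A short sketch of what these serialization defaults buy in practice; the expected output below matches test_serialization_nestedfield_no_doc later in this diff:

    from iceberg.types import NestedField, StringType

    field = NestedField(1, "required_field", StringType(), True)
    field.json()
    # '{"id": 1, "name": "required_field", "type": "string", "required": true}'
    # by_alias=True emits the spec's field names ("id" rather than field_id),
    # and exclude_none=True drops the unset doc attribute entirely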
diff --git a/python/src/iceberg/utils/singleton.py b/python/src/iceberg/utils/singleton.py
index 5643cdd1728f..4dd48d58671b 100644
--- a/python/src/iceberg/utils/singleton.py
+++ b/python/src/iceberg/utils/singleton.py
@@ -26,16 +26,23 @@ return it.
 More information on metaclasses: https://docs.python.org/3/reference/datamodel.html#metaclasses
-
 """
-from typing import ClassVar, Dict
+from typing import Any, ClassVar, Dict
+
+
+def _convert_to_hashable_type(element: Any) -> Any:
+    if isinstance(element, dict):
+        return tuple((_convert_to_hashable_type(k), _convert_to_hashable_type(v)) for k, v in element.items())
+    elif isinstance(element, list):
+        return tuple(map(_convert_to_hashable_type, element))
+    return element
 
 
 class Singleton:
     _instances: ClassVar[Dict] = {}
 
     def __new__(cls, *args, **kwargs):
-        key = (cls, args, tuple(sorted(kwargs.items())))
+        key = (cls, tuple(args), _convert_to_hashable_type(kwargs))
         if key not in cls._instances:
            cls._instances[key] = super().__new__(cls)
         return cls._instances[key]
diff --git a/python/tests/avro/test_reader.py b/python/tests/avro/test_reader.py
index e8c56f122e84..6ea8377e1104 100644
--- a/python/tests/avro/test_reader.py
+++ b/python/tests/avro/test_reader.py
@@ -448,7 +448,7 @@ def test_binary_reader():
 def test_unknown_type():
     class UnknownType(PrimitiveType):
-        ...
+        __root__ = "UnknownType"
 
     with pytest.raises(ValueError) as exc_info:
         primitive_reader(UnknownType())
diff --git a/python/tests/conftest.py b/python/tests/conftest.py
index 782316ec9192..d211fd13dc82 100644
--- a/python/tests/conftest.py
+++ b/python/tests/conftest.py
@@ -15,17 +15,34 @@
 #  specific language governing permissions and limitations
 #  under the License.
 # pylint:disable=redefined-outer-name
+"""This file contains the global pytest configuration.
+Fixtures defined here are discovered automatically and can be requested as an
+argument by any test function.
+
+When a fixture must be referenced from a pytest.mark.parametrize decorator, pass its
+name as a string and add the built-in pytest fixture `request` as an extra argument
+to the test function. The fixture can then be
+retrieved using `request.getfixturevalue(fixture_name)`.
+""" +import os from tempfile import TemporaryDirectory -from typing import Any, Dict +from typing import Any, Dict, Union +from urllib.parse import urlparse import pytest from iceberg import schema +from iceberg.io.base import ( + FileIO, + InputFile, + OutputFile, + OutputStream, +) from iceberg.schema import Schema from iceberg.types import ( BinaryType, BooleanType, + DoubleType, FloatType, IntegerType, ListType, @@ -36,6 +53,7 @@ StructType, ) from tests.catalog.test_base import InMemoryCatalog +from tests.io.test_io_base import LocalInputFile class FooStruct: @@ -768,6 +786,72 @@ def avro_schema_manifest_entry() -> Dict[str, Any]: } +@pytest.fixture(scope="session") +def simple_struct(): + return StructType( + NestedField(id=1, name="required_field", field_type=StringType(), required=True, doc="this is a doc"), + NestedField(id=2, name="optional_field", field_type=IntegerType()), + ) + + +@pytest.fixture(scope="session") +def simple_list(): + return ListType(element_id=22, element=StringType(), element_required=True) + + +@pytest.fixture(scope="session") +def simple_map(): + return MapType(key_id=19, key_type=StringType(), value_id=25, value_type=DoubleType(), value_required=False) + + +class LocalOutputFile(OutputFile): + """An OutputFile implementation for local files (for test use only)""" + + def __init__(self, location: str): + parsed_location = urlparse(location) # Create a ParseResult from the uri + if parsed_location.scheme and parsed_location.scheme != "file": # Validate that a uri is provided with a scheme of `file` + raise ValueError("LocalOutputFile location must have a scheme of `file`") + elif parsed_location.netloc: + raise ValueError(f"Network location is not allowed for LocalOutputFile: {parsed_location.netloc}") + + super().__init__(location=location) + self._path = parsed_location.path + + def __len__(self): + return os.path.getsize(self._path) + + def exists(self): + return os.path.exists(self._path) + + def to_input_file(self): + return LocalInputFile(location=self.location) + + def create(self, overwrite: bool = False) -> OutputStream: + output_file = open(self._path, "wb" if overwrite else "xb") + if not isinstance(output_file, OutputStream): + raise TypeError("Object returned from LocalOutputFile.create(...) does not match the OutputStream protocol.") + return output_file + + +class LocalFileIO(FileIO): + """A FileIO implementation for local files (for test use only)""" + + def new_input(self, location: str): + return LocalInputFile(location=location) + + def new_output(self, location: str): + return LocalOutputFile(location=location) + + def delete(self, location: Union[str, InputFile, OutputFile]): + location = location.location if isinstance(location, (InputFile, OutputFile)) else location + os.remove(location) + + +@pytest.fixture(scope="session", autouse=True) +def LocalFileIOFixture(): + return LocalFileIO + + @pytest.fixture(scope="session") def generated_manifest_entry_file(avro_schema_manifest_entry): from fastavro import parse_schema, writer diff --git a/python/tests/table/test_metadata.py b/python/tests/table/test_metadata.py new file mode 100644 index 000000000000..7e0613b3dbac --- /dev/null +++ b/python/tests/table/test_metadata.py @@ -0,0 +1,522 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import io +import json +from uuid import UUID + +import pytest + +from iceberg.exceptions import ValidationError +from iceberg.schema import Schema +from iceberg.serializers import FromByteStream +from iceberg.table.metadata import TableMetadata, TableMetadataV1, TableMetadataV2 +from iceberg.table.refs import SnapshotRef, SnapshotRefType +from iceberg.types import LongType, NestedField + +EXAMPLE_TABLE_METADATA_V1 = { + "format-version": 1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], +} +EXAMPLE_TABLE_METADATA_V2 = { + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "default-sort-order-id": 3, + "sort-orders": [ + { + "order-id": 3, + "fields": [ + {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"}, + ], + } + ], + "properties": {"read.split.target.size": 134217728}, + "current-snapshot-id": 3055729675574597004, + "snapshots": [ + { + "snapshot-id": 3051729675574597004, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + }, + { + "snapshot-id": 3055729675574597004, + "parent-snapshot-id": 3051729675574597004, + "timestamp-ms": 1555100955770, + "sequence-number": 1, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/2.avro", + "schema-id": 1, + }, + ], + "snapshot-log": [ + {"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, + {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770}, + ], + "metadata-log": [], 
+}
+
+
+@pytest.mark.parametrize(
+    "metadata",
+    [
+        EXAMPLE_TABLE_METADATA_V1,
+        EXAMPLE_TABLE_METADATA_V2,
+    ],
+)
+def test_from_dict(metadata: dict):
+    """Test initialization of a TableMetadata instance from a dictionary"""
+    TableMetadata.parse_obj(metadata)
+
+
+def test_from_byte_stream():
+    """Test generating a TableMetadata instance from a file-like byte stream"""
+    data = bytes(json.dumps(EXAMPLE_TABLE_METADATA_V2), encoding="utf-8")
+    byte_stream = io.BytesIO(data)
+    FromByteStream.table_metadata(byte_stream=byte_stream)
+
+
+def test_v2_metadata_parsing():
+    """Test retrieving values from a TableMetadata instance of version 2"""
+    table_metadata = TableMetadata.parse_obj(EXAMPLE_TABLE_METADATA_V2)
+
+    assert table_metadata.format_version == 2
+    assert table_metadata.table_uuid == UUID("9c12d441-03fe-4693-9a96-a0705ddf69c1")
+    assert table_metadata.location == "s3://bucket/test/location"
+    assert table_metadata.last_sequence_number == 34
+    assert table_metadata.last_updated_ms == 1602638573590
+    assert table_metadata.last_column_id == 3
+    assert table_metadata.schemas[0].schema_id == 0
+    assert table_metadata.current_schema_id == 1
+    assert table_metadata.partition_specs[0]["spec-id"] == 0
+    assert table_metadata.default_spec_id == 0
+    assert table_metadata.last_partition_id == 1000
+    assert table_metadata.properties["read.split.target.size"] == "134217728"
+    assert table_metadata.current_snapshot_id == 3055729675574597004
+    assert table_metadata.snapshots[0]["snapshot-id"] == 3051729675574597004
+    assert table_metadata.snapshot_log[0]["timestamp-ms"] == 1515100955770
+    assert table_metadata.sort_orders[0]["order-id"] == 3
+    assert table_metadata.default_sort_order_id == 3
+
+
+def test_v1_metadata_parsing_directly():
+    """Test retrieving values from a TableMetadata instance of version 1"""
+    table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1)
+
+    assert isinstance(table_metadata, TableMetadataV1)
+
+    # Reading v1 metadata also populates the v2-required fields (schemas,
+    # partition-specs, ...), but the format version itself stays 1
+    assert table_metadata.format_version == 1
+    assert table_metadata.table_uuid == UUID("d20125c8-7284-442c-9aea-15fee620737c")
+    assert table_metadata.location == "s3://bucket/test/location"
+    assert table_metadata.last_updated_ms == 1602638573874
+    assert table_metadata.last_column_id == 3
+    assert table_metadata.schemas[0].schema_id == 0
+    assert table_metadata.current_schema_id == 0
+    assert table_metadata.default_spec_id == 0
+    assert table_metadata.last_partition_id == 1000
+    assert table_metadata.current_snapshot_id is None
+    assert table_metadata.default_sort_order_id == 0
+
+
+def test_parsing_correct_types():
+    table_metadata = TableMetadataV2(**EXAMPLE_TABLE_METADATA_V2)
+    assert isinstance(table_metadata.schemas[0], Schema)
+    assert isinstance(table_metadata.schemas[0].fields[0], NestedField)
+    assert isinstance(table_metadata.schemas[0].fields[0].field_type, LongType)
+
+
+def test_updating_metadata():
+    """Test creating a new TableMetadata instance that's an updated version of
+    an existing TableMetadata instance"""
+    table_metadata = TableMetadataV2(**EXAMPLE_TABLE_METADATA_V2)
+
+    new_schema = {
+        "type": "struct",
+        "schema-id": 1,
+        "fields": [
+            {"id": 1, "name": "foo", "required": True, "type": "string"},
+            {"id": 2, "name": "bar", "required": True, "type": "string"},
+            {"id": 3, "name": "baz", "required": True, "type": "string"},
+        ],
+    }
+
+    mutable_table_metadata = table_metadata.dict()
+    mutable_table_metadata["schemas"].append(new_schema)
+    mutable_table_metadata["current-schema-id"] = 1
+
+ new_table_metadata = TableMetadataV2(**mutable_table_metadata) + + assert new_table_metadata.current_schema_id == 1 + assert new_table_metadata.schemas[-1] == Schema(**new_schema) + + +def test_serialize_v1(): + table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1).json() + assert ( + table_metadata + == """{"location": "s3://bucket/test/location", "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", "last-updated-ms": 1602638573874, "last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}], "current-schema-id": 0, "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {}, "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], "snapshot-log": [], "metadata-log": [], "sort-orders": [{"order_id": 0, "fields": []}], "default-sort-order-id": 0, "refs": {}, "format-version": 1, "schema": {"fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}""" + ) + + +def test_serialize_v2(): + table_metadata = TableMetadataV2(**EXAMPLE_TABLE_METADATA_V2).json() + assert ( + table_metadata + == """{"location": "s3://bucket/test/location", "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", "last-updated-ms": 1602638573590, "last-column-id": 3, "schemas": [{"fields": [{"id": 1, "name": "x", "type": "long", "required": true}], "schema-id": 0, "identifier-field-ids": []}, {"fields": [{"id": 1, "name": "x", "type": "long", "required": true}, {"id": 2, "name": "y", "type": "long", "required": true, "doc": "comment"}, {"id": 3, "name": "z", "type": "long", "required": true}], "schema-id": 1, "identifier-field-ids": [1, 2]}], "current-schema-id": 1, "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], "default-spec-id": 0, "last-partition-id": 1000, "properties": {"read.split.target.size": "134217728"}, "current-snapshot-id": 3055729675574597004, "snapshots": [{"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770, "sequence-number": 0, "summary": {"operation": "append"}, "manifest-list": "s3://a/b/1.avro"}, {"snapshot-id": 3055729675574597004, "parent-snapshot-id": 3051729675574597004, "timestamp-ms": 1555100955770, "sequence-number": 1, "summary": {"operation": "append"}, "manifest-list": "s3://a/b/2.avro", "schema-id": 1}], "snapshot-log": [{"snapshot-id": 3051729675574597004, "timestamp-ms": 1515100955770}, {"snapshot-id": 3055729675574597004, "timestamp-ms": 1555100955770}], "metadata-log": [], "sort-orders": [{"order-id": 3, "fields": [{"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"}]}], "default-sort-order-id": 3, "refs": {"main": {"snapshot-id": 3055729675574597004, "type": "branch"}}, "format-version": 2, "last-sequence-number": 34}""" + ) + + +def test_migrate_v1_schemas(): + table_metadata = 
TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1)
+
+    assert isinstance(table_metadata, TableMetadataV1)
+    assert len(table_metadata.schemas) == 1
+    assert table_metadata.schemas[0] == table_metadata.schema_
+
+
+def test_migrate_v1_partition_specs():
+    # Reading v1 metadata converts the flat partition-spec into a partition-specs list
+    table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1)
+    assert isinstance(table_metadata, TableMetadataV1)
+    assert len(table_metadata.partition_specs) == 1
+    # Spec ID gets added automatically
+    assert table_metadata.partition_specs == [
+        {"spec-id": 0, "fields": [{"field-id": 1000, "name": "x", "source-id": 1, "transform": "identity"}]}
+    ]
+
+
+def test_invalid_format_version():
+    """Test the exception when trying to load an unknown version"""
+    table_metadata_invalid_format_version = {
+        "format-version": -1,
+        "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
+        "location": "s3://bucket/test/location",
+        "last-updated-ms": 1602638573874,
+        "last-column-id": 3,
+        "schema": {
+            "type": "struct",
+            "fields": [
+                {"id": 1, "name": "x", "required": True, "type": "long"},
+                {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
+                {"id": 3, "name": "z", "required": True, "type": "long"},
+            ],
+        },
+        "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}],
+        "properties": {},
+        "current-snapshot-id": -1,
+        "snapshots": [],
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        TableMetadata.parse_obj(table_metadata_invalid_format_version)
+
+    assert "Unknown format version: -1" in str(exc_info.value)
+
+
+def test_current_schema_not_found():
+    """Test that we raise an exception when the schema can't be found"""
+
+    table_metadata_schema_not_found = {
+        "format-version": 2,
+        "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
+        "location": "s3://bucket/test/location",
+        "last-updated-ms": 1602638573874,
+        "last-column-id": 3,
+        "schemas": [
+            {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]},
+            {
+                "type": "struct",
+                "schema-id": 1,
+                "identifier-field-ids": [1, 2],
+                "fields": [
+                    {"id": 1, "name": "x", "required": True, "type": "long"},
+                    {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
+                    {"id": 3, "name": "z", "required": True, "type": "long"},
+                ],
+            },
+        ],
+        "current-schema-id": 2,
+        "default-spec-id": 0,
+        "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}],
+        "last-partition-id": 1000,
+        "default-sort-order-id": 0,
+        "properties": {},
+        "current-snapshot-id": -1,
+        "snapshots": [],
+    }
+
+    with pytest.raises(ValidationError) as exc_info:
+        TableMetadata.parse_obj(table_metadata_schema_not_found)
+
+    assert "current-schema-id 2 can't be found in the schemas" in str(exc_info.value)
+
+
+def test_sort_order_not_found():
+    """Test that we raise an exception when the default sort order can't be found"""
+
+    table_metadata_sort_order_not_found = {
+        "format-version": 2,
+        "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c",
+        "location": "s3://bucket/test/location",
+        "last-updated-ms": 1602638573874,
+        "last-column-id": 3,
+        "schemas": [
+            {
+                "type": "struct",
+                "schema-id": 0,
+                "identifier-field-ids": [1, 2],
+                "fields": [
+                    {"id": 1, "name": "x", "required": True, "type": "long"},
+                    {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"},
+                    {"id": 3, "name": "z", "required": True, "type": "long"},
+                ],
+            },
+        ],
+        "default-sort-order-id": 4,
+        "sort-orders": [
+            {
"order-id": 3, + "fields": [ + {"transform": "identity", "source-id": 2, "direction": "asc", "null-order": "nulls-first"}, + {"transform": "bucket[4]", "source-id": 3, "direction": "desc", "null-order": "nulls-last"}, + ], + } + ], + "current-schema-id": 0, + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [], + } + + with pytest.raises(ValidationError) as exc_info: + TableMetadata.parse_obj(table_metadata_schema_not_found) + + assert "default-sort-order-id 4 can't be found" in str(exc_info.value) + + +def test_sort_order_unsorted(): + """Test that we raise an exception when the schema can't be found""" + + table_metadata_schema_not_found = { + "format-version": 2, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schemas": [ + { + "type": "struct", + "schema-id": 0, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + ], + "default-sort-order-id": 0, + "sort-orders": [], + "current-schema-id": 0, + "default-spec-id": 0, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [], + } + + table_metadata = TableMetadata.parse_obj(table_metadata_schema_not_found) + + # Most important here is that we correctly handle sort-order-id 0 + assert len(table_metadata.sort_orders) == 0 + + +def test_invalid_partition_spec(): + table_metadata_spec_not_found = { + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + ], + "sort-orders": [], + "default-sort-order-id": 0, + "default-spec-id": 1, + "partition-specs": [{"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]}], + "last-partition-id": 1000, + } + with pytest.raises(ValidationError) as exc_info: + TableMetadata.parse_obj(table_metadata_spec_not_found) + + assert "default-spec-id 1 can't be found" in str(exc_info.value) + + +def test_v1_writing_metadata(): + """ + https://iceberg.apache.org/spec/#version-2 + + Writing v1 metadata: + - Table metadata field last-sequence-number should not be written + """ + + table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1) + metadata_v1_json = table_metadata.json() + metadata_v1 = json.loads(metadata_v1_json) + + assert "last-sequence-number" not in metadata_v1 + + +def test_v1_metadata_for_v2(): + """ + https://iceberg.apache.org/spec/#version-2 + + Reading v1 metadata for v2: + - Table metadata field 
last-sequence-number must default to 0 + """ + + table_metadata = TableMetadataV1(**EXAMPLE_TABLE_METADATA_V1).to_v2() + + assert table_metadata.last_sequence_number == 0 + + +def test_v1_write_metadata_for_v2(): + """ + https://iceberg.apache.org/spec/#version-2 + + Table metadata JSON: + - last-sequence-number was added and is required; default to 0 when reading v1 metadata + - table-uuid is now required + - current-schema-id is now required + - schemas is now required + - partition-specs is now required + - default-spec-id is now required + - last-partition-id is now required + - sort-orders is now required + - default-sort-order-id is now required + - schema is no longer required and should be omitted; use schemas and current-schema-id instead + - partition-spec is no longer required and should be omitted; use partition-specs and default-spec-id instead + """ + + minimal_example_v1 = { + "format-version": 1, + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], + } + + table_metadata = TableMetadataV1(**minimal_example_v1).to_v2() + metadata_v2_json = table_metadata.json() + metadata_v2 = json.loads(metadata_v2_json) + + assert metadata_v2["last-sequence-number"] == 0 + assert UUID(metadata_v2["table-uuid"]) is not None + assert metadata_v2["current-schema-id"] == 0 + assert metadata_v2["schemas"] == [ + { + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"doc": "comment", "id": 2, "name": "y", "required": True, "type": "long"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + "identifier-field-ids": [], + "schema-id": 0, + } + ] + assert metadata_v2["partition-specs"] == [ + {"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]} + ] + assert metadata_v2["default-spec-id"] == 0 + assert metadata_v2["last-partition-id"] == 1000 + assert metadata_v2["sort-orders"] == [{"fields": [], "order_id": 0}] + assert metadata_v2["default-sort-order-id"] == 0 + # Deprecated fields + assert "schema" not in metadata_v2 + assert "partition-spec" not in metadata_v2 + + +def test_v2_ref_creation(): + table_metadata = TableMetadataV2(**EXAMPLE_TABLE_METADATA_V2) + assert table_metadata.refs == {"main": SnapshotRef(snapshot_id=3055729675574597004, snapshot_ref_type=SnapshotRefType.BRANCH)} diff --git a/python/tests/test_schema.py b/python/tests/test_schema.py index 81ebcf355693..a4f93c18177d 100644 --- a/python/tests/test_schema.py +++ b/python/tests/test_schema.py @@ -23,7 +23,7 @@ from iceberg import schema from iceberg.expressions.base import Accessor from iceberg.files import StructProtocol -from iceberg.schema import build_position_accessors +from iceberg.schema import Schema, build_position_accessors from iceberg.types import ( BooleanType, FloatType, @@ -36,14 +36,14 @@ ) -def test_schema_str(table_schema_simple): +def test_schema_str(table_schema_simple: Schema): """Test casting a schema to a string""" assert str(table_schema_simple) == dedent( """\ table { - 1: foo: required string - 2: bar: 
optional int - 3: baz: required boolean + 1: foo: optional string + 2: bar: required int + 3: baz: optional boolean }""" ) @@ -61,7 +61,7 @@ def test_schema_str(table_schema_simple): ), ], ) -def test_schema_repr(schema_repr, expected_repr): +def test_schema_repr(schema_repr: Schema, expected_repr: str): """Test schema representation""" assert repr(schema_repr) == expected_repr @@ -107,8 +107,8 @@ def test_schema_index_by_id_visitor(table_schema_nested): ), required=True, ), - 7: NestedField(field_id=7, name="key", field_type=StringType(), required=False), - 9: NestedField(field_id=9, name="key", field_type=StringType(), required=False), + 7: NestedField(field_id=7, name="key", field_type=StringType(), required=True), + 9: NestedField(field_id=9, name="key", field_type=StringType(), required=True), 8: NestedField( field_id=8, name="value", @@ -274,14 +274,14 @@ def test_index_by_id_schema_visitor(table_schema_nested): ), required=True, ), - 7: NestedField(field_id=7, name="key", field_type=StringType(), required=False), + 7: NestedField(field_id=7, name="key", field_type=StringType(), required=True), 8: NestedField( field_id=8, name="value", field_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), required=True, ), - 9: NestedField(field_id=9, name="key", field_type=StringType(), required=False), + 9: NestedField(field_id=9, name="key", field_type=StringType(), required=True), 10: NestedField(field_id=10, name="value", field_type=IntegerType(), required=True), 11: NestedField( field_id=11, @@ -386,7 +386,7 @@ def test_build_position_accessors(table_schema_nested): } -def test_build_position_accessors_with_struct(table_schema_nested): +def test_build_position_accessors_with_struct(table_schema_nested: Schema): class TestStruct(StructProtocol): def __init__(self, pos: Dict[int, Any] = None): self._pos: Dict[int, Any] = pos or {} @@ -399,4 +399,20 @@ def get(self, pos: int) -> Any: accessors = build_position_accessors(table_schema_nested) container = TestStruct({6: TestStruct({0: "name"})}) - assert accessors.get(16).get(container) == "name" + inner_accessor = accessors.get(16) + assert inner_accessor + assert inner_accessor.get(container) == "name" + + +def test_serialize_schema(table_schema_simple: Schema): + actual = table_schema_simple.json() + expected = """{"fields": [{"id": 1, "name": "foo", "type": "string", "required": false}, {"id": 2, "name": "bar", "type": "int", "required": true}, {"id": 3, "name": "baz", "type": "boolean", "required": false}], "schema-id": 1, "identifier-field-ids": [1]}""" + assert actual == expected + + +def test_deserialize_schema(table_schema_simple: Schema): + actual = Schema.parse_raw( + """{"fields": [{"id": 1, "name": "foo", "type": "string", "required": false}, {"id": 2, "name": "bar", "type": "int", "required": true}, {"id": 3, "name": "baz", "type": "boolean", "required": false}], "schema-id": 1, "identifier-field-ids": [1]}""" + ) + expected = table_schema_simple + assert actual == expected diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 3d75b88a58d6..1695f2657127 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -26,6 +26,7 @@ DoubleType, FixedType, FloatType, + IcebergType, IntegerType, ListType, LongType, @@ -38,6 +39,7 @@ TimeType, UUIDType, ) +from iceberg.utils.iceberg_base_model import IcebergBaseModel non_parameterized_types = [ (1, BooleanType), @@ -146,9 +148,9 @@ def test_list_type(): ), False, ) - assert 
isinstance(type_var.element.field_type, StructType) - assert len(type_var.element.field_type.fields) == 2 - assert type_var.element.field_id == 1 + assert isinstance(type_var.element_field.field_type, StructType) + assert len(type_var.element_field.field_type.fields) == 2 + assert type_var.element_field.field_id == 1 assert str(type_var) == str(eval(repr(type_var))) assert type_var == eval(repr(type_var)) assert type_var != ListType( @@ -162,10 +164,10 @@ def test_list_type(): def test_map_type(): type_var = MapType(1, DoubleType(), 2, UUIDType(), False) - assert isinstance(type_var.key.field_type, DoubleType) - assert type_var.key.field_id == 1 - assert isinstance(type_var.value.field_type, UUIDType) - assert type_var.value.field_id == 2 + assert isinstance(type_var.key_field.field_type, DoubleType) + assert type_var.key_field.field_id == 1 + assert isinstance(type_var.value_field.field_type, UUIDType) + assert type_var.value_field.field_id == 2 assert str(type_var) == str(eval(repr(type_var))) assert type_var == eval(repr(type_var)) assert type_var != MapType(1, LongType(), 2, UUIDType(), False) @@ -206,6 +208,390 @@ def test_non_parameterized_type_equality(input_index, input_type, check_index, c assert input_type() != check_type() +# Examples based on https://iceberg.apache.org/spec/#appendix-c-json-serialization + + +class TestType(IcebergBaseModel): + __root__: IcebergType + + +def test_serialization_boolean(): + assert BooleanType().json() == '"boolean"' + + +def test_deserialization_boolean(): + assert TestType.parse_raw('"boolean"') == BooleanType() + + +def test_str_boolean(): + assert str(BooleanType()) == "boolean" + + +def test_repr_boolean(): + assert repr(BooleanType()) == "BooleanType()" + + +def test_serialization_int(): + assert IntegerType().json() == '"int"' + + +def test_deserialization_int(): + assert TestType.parse_raw('"int"') == IntegerType() + + +def test_str_int(): + assert str(IntegerType()) == "int" + + +def test_repr_int(): + assert repr(IntegerType()) == "IntegerType()" + + +def test_serialization_long(): + assert LongType().json() == '"long"' + + +def test_deserialization_long(): + assert TestType.parse_raw('"long"') == LongType() + + +def test_str_long(): + assert str(LongType()) == "long" + + +def test_repr_long(): + assert repr(LongType()) == "LongType()" + + +def test_serialization_float(): + assert FloatType().json() == '"float"' + + +def test_deserialization_float(): + assert TestType.parse_raw('"float"') == FloatType() + + +def test_str_float(): + assert str(FloatType()) == "float" + + +def test_repr_float(): + assert repr(FloatType()) == "FloatType()" + + +def test_serialization_double(): + assert DoubleType().json() == '"double"' + + +def test_deserialization_double(): + assert TestType.parse_raw('"double"') == DoubleType() + + +def test_str_double(): + assert str(DoubleType()) == "double" + + +def test_repr_double(): + assert repr(DoubleType()) == "DoubleType()" + + +def test_serialization_date(): + assert DateType().json() == '"date"' + + +def test_deserialization_date(): + assert TestType.parse_raw('"date"') == DateType() + + +def test_str_date(): + assert str(DateType()) == "date" + + +def test_repr_date(): + assert repr(DateType()) == "DateType()" + + +def test_serialization_time(): + assert TimeType().json() == '"time"' + + +def test_deserialization_time(): + assert TestType.parse_raw('"time"') == TimeType() + + +def test_str_time(): + assert str(TimeType()) == "time" + + +def test_repr_time(): + assert repr(TimeType()) == "TimeType()" + + +def 
test_serialization_timestamp(): + assert TimestampType().json() == '"timestamp"' + + +def test_deserialization_timestamp(): + assert TestType.parse_raw('"timestamp"') == TimestampType() + + +def test_str_timestamp(): + assert str(TimestampType()) == "timestamp" + + +def test_repr_timestamp(): + assert repr(TimestampType()) == "TimestampType()" + + +def test_serialization_timestamptz(): + assert TimestamptzType().json() == '"timestamptz"' + + +def test_deserialization_timestamptz(): + assert TestType.parse_raw('"timestamptz"') == TimestamptzType() + + +def test_str_timestamptz(): + assert str(TimestamptzType()) == "timestamptz" + + +def test_repr_timestamptz(): + assert repr(TimestamptzType()) == "TimestamptzType()" + + +def test_serialization_string(): + assert StringType().json() == '"string"' + + +def test_deserialization_string(): + assert TestType.parse_raw('"string"') == StringType() + + +def test_str_string(): + assert str(StringType()) == "string" + + +def test_repr_string(): + assert repr(StringType()) == "StringType()" + + +def test_serialization_uuid(): + assert UUIDType().json() == '"uuid"' + + +def test_deserialization_uuid(): + assert TestType.parse_raw('"uuid"') == UUIDType() + + +def test_str_uuid(): + assert str(UUIDType()) == "uuid" + + +def test_repr_uuid(): + assert repr(UUIDType()) == "UUIDType()" + + +def test_serialization_fixed(): + assert FixedType(22).json() == '"fixed[22]"' + + +def test_deserialization_fixed(): + fixed = TestType.parse_raw('"fixed[22]"') + assert fixed == FixedType(22) + + inner = fixed.__root__ + assert isinstance(inner, FixedType) + assert inner.length == 22 + + +def test_str_fixed(): + assert str(FixedType(22)) == "fixed[22]" + + +def test_repr_fixed(): + assert repr(FixedType(22)) == "FixedType(length=22)" + + +def test_serialization_binary(): + assert BinaryType().json() == '"binary"' + + +def test_deserialization_binary(): + assert TestType.parse_raw('"binary"') == BinaryType() + + +def test_str_binary(): + assert str(BinaryType()) == "binary" + + +def test_repr_binary(): + assert repr(BinaryType()) == "BinaryType()" + + +def test_serialization_decimal(): + assert DecimalType(19, 25).json() == '"decimal(19, 25)"' + + +def test_deserialization_decimal(): + decimal = TestType.parse_raw('"decimal(19, 25)"') + assert decimal == DecimalType(19, 25) + + inner = decimal.__root__ + assert isinstance(inner, DecimalType) + assert inner.precision == 19 + assert inner.scale == 25 + + +def test_str_decimal(): + assert str(DecimalType(19, 25)) == "decimal(19, 25)" + + +def test_repr_decimal(): + assert repr(DecimalType(19, 25)) == "DecimalType(precision=19, scale=25)" + + +def test_serialization_nestedfield(): + expected = '{"id": 1, "name": "required_field", "type": "string", "required": true, "doc": "this is a doc"}' + actual = NestedField(1, "required_field", StringType(), True, "this is a doc").json() + assert expected == actual + + +def test_serialization_nestedfield_no_doc(): + expected = '{"id": 1, "name": "required_field", "type": "string", "required": true}' + actual = NestedField(1, "required_field", StringType(), True).json() + assert expected == actual + + +def test_str_nestedfield(): + assert str(NestedField(1, "required_field", StringType(), True)) == "1: required_field: required string" + + +def test_repr_nestedfield(): + assert ( + repr(NestedField(1, "required_field", StringType(), True)) + == "NestedField(field_id=1, name='required_field', field_type=StringType(), required=True)" + ) + + +def test_nestedfield_by_alias(): + # We should 
be able to initialize a NestedField by alias
+    expected = NestedField(1, "required_field", StringType(), True, "this is a doc")
+    actual = NestedField(**{"id": 1, "name": "required_field", "type": "string", "required": True, "doc": "this is a doc"})
+    assert expected == actual
+
+
+def test_deserialization_nestedfield():
+    expected = NestedField(1, "required_field", StringType(), True, "this is a doc")
+    actual = NestedField.parse_raw(
+        '{"id": 1, "name": "required_field", "type": "string", "required": true, "doc": "this is a doc"}'
+    )
+    assert expected == actual
+
+
+def test_deserialization_nestedfield_inner():
+    expected = NestedField(1, "required_field", StringType(), True, "this is a doc")
+    actual = TestType.parse_raw('{"id": 1, "name": "required_field", "type": "string", "required": true, "doc": "this is a doc"}')
+    assert expected == actual.__root__
+
+
+def test_serialization_struct():
+    actual = StructType(
+        NestedField(1, "required_field", StringType(), True, "this is a doc"), NestedField(2, "optional_field", IntegerType())
+    ).json()
+    expected = (
+        '{"type": "struct", "fields": ['
+        '{"id": 1, "name": "required_field", "type": "string", "required": true, "doc": "this is a doc"}, '
+        '{"id": 2, "name": "optional_field", "type": "int", "required": true}'
+        "]}"
+    )
+    assert actual == expected
+
+
+def test_deserialization_struct():
+    actual = StructType.parse_raw(
+        """
+        {
+            "type": "struct",
+            "fields": [{
+                    "id": 1,
+                    "name": "required_field",
+                    "type": "string",
+                    "required": true,
+                    "doc": "this is a doc"
+                },
+                {
+                    "id": 2,
+                    "name": "optional_field",
+                    "type": "int",
+                    "required": true
+                }
+            ]
+        }
+        """
+    )
+
+    expected = StructType(
+        NestedField(1, "required_field", StringType(), True, "this is a doc"), NestedField(2, "optional_field", IntegerType())
+    )
+
+    assert actual == expected
+
+
+def test_str_struct(simple_struct: StructType):
+    assert str(simple_struct) == "struct<1: required_field: required string (this is a doc), 2: optional_field: required int>"
+
+
+def test_repr_struct(simple_struct: StructType):
+    assert (
+        repr(simple_struct)
+        == "StructType(fields=(NestedField(field_id=1, name='required_field', field_type=StringType(), required=True), NestedField(field_id=2, name='optional_field', field_type=IntegerType(), required=True),))"
+    )
+
+
+def test_serialization_list(simple_list: ListType):
+    actual = simple_list.json()
+    expected = '{"type": "list", "element-id": 22, "element": "string", "element-required": true}'
+    assert actual == expected
+
+
+def test_deserialization_list(simple_list: ListType):
+    actual = ListType.parse_raw('{"type": "list", "element-id": 22, "element": "string", "element-required": true}')
+    assert actual == simple_list
+
+
+def test_str_list(simple_list: ListType):
+    assert str(simple_list) == "list<string>"
+
+
+def test_repr_list(simple_list: ListType):
+    assert repr(simple_list) == "ListType(type='list', element_id=22, element_type=StringType(), element_required=True)"
+
+
+def test_serialization_map(simple_map: MapType):
+    actual = simple_map.json()
+    expected = """{"type": "map", "key-id": 19, "key": "string", "value-id": 25, "value": "double", "value-required": false}"""
+
+    assert actual == expected
+
+
+def test_deserialization_map(simple_map: MapType):
+    actual = MapType.parse_raw(
+        """{"type": "map", "key-id": 19, "key": "string", "value-id": 25, "value": "double", "value-required": false}"""
+    )
+    assert actual == simple_map
+
+
+def test_str_map(simple_map: MapType):
+    assert str(simple_map) == "map<string, double>"
+
+
+def test_repr_map(simple_map: MapType):
+    assert (
+        repr(simple_map)
+        == "MapType(type='map', key_id=19, key_type=StringType(), value_id=25, value_type=DoubleType(), value_required=False)"
+    )
+
+
 def test_types_singleton():
     """The types are immutable so we can return the same instance multiple times"""
     assert id(BooleanType()) == id(BooleanType())
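Finally, a small illustration of why singleton.py now converts kwargs before building the cache key: with pydantic in the constructor path, dict- and list-valued kwargs can reach Singleton.__new__, and hashing the old (cls, args, tuple(sorted(kwargs.items()))) key raised TypeError on unhashable values. A sketch only; it calls the module-private helper directly purely for illustration:

    from iceberg.utils.singleton import _convert_to_hashable_type

    _convert_to_hashable_type({"fields": [{"id": 1}]})
    # (('fields', ((('id', 1),),)),)
    # Nested dicts and lists become nested tuples, which hash cleanly as cache keys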