diff --git a/python/pyiceberg/io/pyarrow.py b/python/pyiceberg/io/pyarrow.py
index ae817339689c..f01baefa9273 100644
--- a/python/pyiceberg/io/pyarrow.py
+++ b/python/pyiceberg/io/pyarrow.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=redefined-outer-name,arguments-renamed
+# pylint: disable=redefined-outer-name,arguments-renamed,fixme
 """FileIO implementation for reading and writing table files that uses pyarrow.fs

 This file contains a FileIO implementation that relies on the filesystem interface provided
@@ -26,19 +26,23 @@
 import multiprocessing
 import os
-from functools import lru_cache
+from abc import ABC, abstractmethod
+from functools import lru_cache, singledispatch
 from multiprocessing.pool import ThreadPool
 from multiprocessing.sharedctypes import Synchronized
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
+    Generic,
     Iterable,
     List,
     Optional,
     Set,
     Tuple,
+    TypeVar,
     Union,
+    cast,
 )
 from urllib.parse import urlparse
@@ -122,6 +126,12 @@
 ONE_MEGABYTE = 1024 * 1024
 BUFFER_SIZE = "buffer-size"
 ICEBERG_SCHEMA = b"iceberg.schema"
+FIELD_ID = "field_id"
+DOC = "doc"
+PYARROW_FIELD_ID_KEYS = [b"PARQUET:field_id", b"field_id"]
+PYARROW_FIELD_DOC_KEYS = [b"PARQUET:field_doc", b"field_doc", b"doc"]
+
+T = TypeVar("T")


 class PyArrowFile(InputFile, OutputFile):
@@ -358,14 +368,17 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:
             name=field.name,
             type=field_result,
             nullable=field.optional,
-            metadata={"doc": field.doc, "id": str(field.field_id)} if field.doc else {},
+            metadata={DOC: field.doc, FIELD_ID: str(field.field_id)} if field.doc else {FIELD_ID: str(field.field_id)},
         )

-    def list(self, _: ListType, element_result: pa.DataType) -> pa.DataType:
-        return pa.list_(value_type=element_result)
+    def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
+        element_field = self.field(list_type.element_field, element_result)
+        return pa.list_(value_type=element_field)

-    def map(self, _: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
-        return pa.map_(key_type=key_result, item_type=value_result)
+    def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
+        key_field = self.field(map_type.key_field, key_result)
+        value_field = self.field(map_type.value_field, value_result)
+        return pa.map_(key_type=key_field, item_type=value_field)

     def visit_fixed(self, fixed_type: FixedType) -> pa.DataType:
         return pa.binary(len(fixed_type))
@@ -486,6 +499,194 @@ def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
     return boolean_expression_visit(expr, _ConvertToArrowExpression())


+def pyarrow_to_schema(schema: pa.Schema) -> Schema:
+    visitor = _ConvertToIceberg()
+    return visit_pyarrow(schema, visitor)
+
+
+@singledispatch
+def visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PyArrowSchemaVisitor[T]) -> T:
+    """A generic function for applying a pyarrow schema visitor to any point within a schema.
+
+    The function traverses the schema in post-order fashion.
+
+    Args:
+        obj (Union[pa.DataType, pa.Schema]): An instance of a pa.Schema or a pa.DataType
+        visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the generic PyArrowSchemaVisitor base class
+
+    Raises:
+        NotImplementedError: If attempting to visit an unrecognized object type
+    """
+    raise NotImplementedError(f"Cannot visit non-type: {obj}")
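A quick illustration of how this @singledispatch entry point behaves once the registrations below are in place: visiting a bare primitive type falls through to the pa.DataType registration, which delegates to visitor.primitive(). This is a minimal sketch, assuming the patch is applied and pyarrow plus pyiceberg are importable; _ConvertToIceberg is the concrete visitor added further down in this file.

    import pyarrow as pa

    from pyiceberg.io.pyarrow import _ConvertToIceberg, visit_pyarrow
    from pyiceberg.types import LongType

    # pa.int64() is not a pa.Schema, pa.StructType, pa.ListType or pa.MapType,
    # so singledispatch falls back to the pa.DataType registration, which calls
    # visitor.primitive(...) and maps int64 to Iceberg's LongType.
    assert visit_pyarrow(pa.int64(), _ConvertToIceberg()) == LongType()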
+
+
+@visit_pyarrow.register(pa.Schema)
+def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    struct_results: List[Optional[T]] = []
+    for field in obj:
+        visitor.before_field(field)
+        struct_result = visit_pyarrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.schema(obj, struct_results)
+
+
+@visit_pyarrow.register(pa.StructType)
+def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    struct_results: List[Optional[T]] = []
+    for field in obj:
+        visitor.before_field(field)
+        struct_result = visit_pyarrow(field.type, visitor)
+        visitor.after_field(field)
+        struct_results.append(struct_result)
+
+    return visitor.struct(obj, struct_results)
+
+
+@visit_pyarrow.register(pa.ListType)
+def _(obj: pa.ListType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    visitor.before_field(obj.value_field)
+    list_result = visit_pyarrow(obj.value_field.type, visitor)
+    visitor.after_field(obj.value_field)
+    return visitor.list(obj, list_result)
+
+
+@visit_pyarrow.register(pa.MapType)
+def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    visitor.before_field(obj.key_field)
+    key_result = visit_pyarrow(obj.key_field.type, visitor)
+    visitor.after_field(obj.key_field)
+    visitor.before_field(obj.item_field)
+    value_result = visit_pyarrow(obj.item_field.type, visitor)
+    visitor.after_field(obj.item_field)
+    return visitor.map(obj, key_result, value_result)
+
+
+@visit_pyarrow.register(pa.DataType)
+def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> Optional[T]:
+    if pa.types.is_nested(obj):
+        raise TypeError(f"Expected primitive type, got: {type(obj)}")
+    return visitor.primitive(obj)
+
+
+class PyArrowSchemaVisitor(Generic[T], ABC):
+    def before_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately before visiting a field."""
+
+    def after_field(self, field: pa.Field) -> None:
+        """Override this method to perform an action immediately after visiting a field."""
+
+    @abstractmethod
+    def schema(self, schema: pa.Schema, field_results: List[Optional[T]]) -> Optional[T]:
+        """Visit a schema."""
+
+    @abstractmethod
+    def struct(self, struct: pa.StructType, field_results: List[Optional[T]]) -> Optional[T]:
+        """Visit a struct."""
+
+    @abstractmethod
+    def list(self, list_type: pa.ListType, element_result: Optional[T]) -> Optional[T]:
+        """Visit a list."""
+
+    @abstractmethod
+    def map(self, map_type: pa.MapType, key_result: Optional[T], value_result: Optional[T]) -> Optional[T]:
+        """Visit a map."""
+
+    @abstractmethod
+    def primitive(self, primitive: pa.DataType) -> Optional[T]:
+        """Visit a primitive type."""
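The abstract base class above makes it cheap to write other traversals over a pyarrow schema. As a toy sketch that is not part of the patch (the class name is hypothetical), a visitor that flattens a schema into the string names of its primitive types could look like this:

    from typing import List, Optional

    import pyarrow as pa

    class CollectPrimitiveNames(PyArrowSchemaVisitor[List[str]]):
        """Toy visitor: collect the string form of every primitive type in a schema."""

        def schema(self, schema: pa.Schema, field_results: List[Optional[List[str]]]) -> List[str]:
            return [name for result in field_results if result for name in result]

        def struct(self, struct: pa.StructType, field_results: List[Optional[List[str]]]) -> List[str]:
            return [name for result in field_results if result for name in result]

        def list(self, list_type: pa.ListType, element_result: Optional[List[str]]) -> List[str]:
            return element_result or []

        def map(self, map_type: pa.MapType, key_result: Optional[List[str]], value_result: Optional[List[str]]) -> List[str]:
            return (key_result or []) + (value_result or [])

        def primitive(self, primitive: pa.DataType) -> List[str]:
            return [str(primitive)]

    # visit_pyarrow(pa.schema([("a", pa.int32()), ("b", pa.string())]), CollectPrimitiveNames())
    # returns ["int32", "string"]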
+
+
+def _get_field_id(field: pa.Field) -> Optional[int]:
+    if field.metadata is None:
+        return None
+    for pyarrow_field_id_key in PYARROW_FIELD_ID_KEYS:
+        if field_id_str := field.metadata.get(pyarrow_field_id_key):
+            return int(field_id_str.decode())
+    return None
+
+
+def _get_field_doc(field: pa.Field) -> Optional[str]:
+    if field.metadata is None:
+        return None
+    for pyarrow_doc_key in PYARROW_FIELD_DOC_KEYS:
+        if doc_str := field.metadata.get(pyarrow_doc_key):
+            return doc_str.decode()
+    return None
+
+
+class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]):
+    def _convert_fields(self, arrow_fields: Iterable[pa.Field], field_results: List[Optional[IcebergType]]) -> List[NestedField]:
+        fields = []
+        for i, field in enumerate(arrow_fields):
+            field_id = _get_field_id(field)
+            field_doc = _get_field_doc(field)
+            field_type = field_results[i]
+            if field_type is not None and field_id is not None:
+                fields.append(NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc))
+        return fields
+
+    def schema(self, schema: pa.Schema, field_results: List[Optional[IcebergType]]) -> Schema:
+        return Schema(*self._convert_fields(schema, field_results))
+
+    def struct(self, struct: pa.StructType, field_results: List[Optional[IcebergType]]) -> IcebergType:
+        return StructType(*self._convert_fields(struct, field_results))
+
+    def list(self, list_type: pa.ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        element_field = list_type.value_field
+        element_id = _get_field_id(element_field)
+        if element_result is not None and element_id is not None:
+            return ListType(element_id, element_result, element_required=not element_field.nullable)
+        return None
+
+    def map(
+        self, map_type: pa.MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+    ) -> Optional[IcebergType]:
+        key_field = map_type.key_field
+        key_id = _get_field_id(key_field)
+        value_field = map_type.item_field
+        value_id = _get_field_id(value_field)
+        if key_result is not None and value_result is not None and key_id is not None and value_id is not None:
+            return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)
+        return None
+
+    def primitive(self, primitive: pa.DataType) -> IcebergType:
+        if pa.types.is_boolean(primitive):
+            return BooleanType()
+        elif pa.types.is_int32(primitive):
+            return IntegerType()
+        elif pa.types.is_int64(primitive):
+            return LongType()
+        elif pa.types.is_float32(primitive):
+            return FloatType()
+        elif pa.types.is_float64(primitive):
+            return DoubleType()
+        elif isinstance(primitive, pa.Decimal128Type):
+            primitive = cast(pa.Decimal128Type, primitive)
+            return DecimalType(primitive.precision, primitive.scale)
+        elif pa.types.is_string(primitive):
+            return StringType()
+        elif pa.types.is_date32(primitive):
+            return DateType()
+        elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us":
+            return TimeType()
+        elif pa.types.is_timestamp(primitive):
+            primitive = cast(pa.TimestampType, primitive)
+            if primitive.unit == "us":
+                if primitive.tz == "UTC" or primitive.tz == "+00:00":
+                    return TimestamptzType()
+                elif primitive.tz is None:
+                    return TimestampType()
+        elif pa.types.is_binary(primitive):
+            return BinaryType()
+        elif pa.types.is_fixed_size_binary(primitive):
+            primitive = cast(pa.FixedSizeBinaryType, primitive)
+            return FixedType(primitive.byte_width)
+
+        raise TypeError(f"Unsupported type: {primitive}")
+
+
 def _file_to_table(
     fs: FileSystem,
     task: FileScanTask,
@@ -507,11 +708,9 @@ def _file_to_table(
     schema_raw = None
     if metadata := physical_schema.metadata:
         schema_raw = metadata.get(ICEBERG_SCHEMA)
-    if schema_raw is None:
-        raise ValueError(
-            "Iceberg schema is not embedded into the Parquet file, see https://github.com/apache/iceberg/issues/6505"
-        )
-    file_schema = Schema.parse_raw(schema_raw)
+    # TODO: if field_ids are not present, Name Mapping should be implemented to look them up in the table schema,
+    # see https://github.com/apache/iceberg/issues/7451
+    file_schema = Schema.parse_raw(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)

     pyarrow_filter = None
     if bound_row_filter is not AlwaysTrue():
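For context on the _file_to_table change just above: a Parquet file written by pyiceberg embeds the Iceberg schema as JSON under the b"iceberg.schema" metadata key, while files written by other engines may only carry per-column field ids, and the new fallback covers that case. A hedged sketch of the decision, with a made-up file name:

    import pyarrow.parquet as pq

    from pyiceberg.io.pyarrow import ICEBERG_SCHEMA, pyarrow_to_schema
    from pyiceberg.schema import Schema

    physical_schema = pq.read_schema("00000-0-data.parquet")  # hypothetical path
    metadata = physical_schema.metadata or {}
    schema_raw = metadata.get(ICEBERG_SCHEMA)
    # Pre-patch, a missing key raised ValueError; now the schema is inferred
    # from the field-id metadata carried by each Parquet column.
    file_schema = Schema.parse_raw(schema_raw) if schema_raw is not None else pyarrow_to_schema(physical_schema)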
diff --git a/python/tests/io/test_pyarrow.py b/python/tests/io/test_pyarrow.py
index 7a78b1c0a9d6..ac09bcfb4b36 100644
--- a/python/tests/io/test_pyarrow.py
+++ b/python/tests/io/test_pyarrow.py
@@ -303,24 +303,58 @@ def test_deleting_s3_file_not_found() -> None:
 def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None:
     actual = schema_to_pyarrow(table_schema_nested)
     expected = """foo: string
+  -- field metadata --
+  field_id: '1'
 bar: int32 not null
+  -- field metadata --
+  field_id: '2'
 baz: bool
-qux: list<item: string> not null
-  child 0, item: string
+  -- field metadata --
+  field_id: '3'
+qux: list<element: string not null> not null
+  child 0, element: string not null
+    -- field metadata --
+    field_id: '5'
+  -- field metadata --
+  field_id: '4'
 quux: map<string, map<string, int32>> not null
-  child 0, entries: struct<key: string not null, value: map<string, int32>> not null
+  child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null
       child 0, key: string not null
-      child 1, value: map<string, int32>
-          child 0, entries: struct<key: string not null, value: int32> not null
+      -- field metadata --
+      field_id: '7'
+      child 1, value: map<string, int32> not null
+          child 0, entries: struct<key: string not null, value: int32 not null> not null
               child 0, key: string not null
-              child 1, value: int32
-location: list<item: struct<latitude: float, longitude: float>> not null
-  child 0, item: struct<latitude: float, longitude: float>
+              -- field metadata --
+              field_id: '9'
+              child 1, value: int32 not null
+              -- field metadata --
+              field_id: '10'
+      -- field metadata --
+      field_id: '8'
+  -- field metadata --
+  field_id: '6'
+location: list<element: struct<latitude: float, longitude: float> not null> not null
+  child 0, element: struct<latitude: float, longitude: float> not null
       child 0, latitude: float
+      -- field metadata --
+      field_id: '13'
       child 1, longitude: float
+      -- field metadata --
+      field_id: '14'
+    -- field metadata --
+    field_id: '12'
+  -- field metadata --
+  field_id: '11'
 person: struct<name: string, age: int32 not null>
   child 0, name: string
-  child 1, age: int32 not null"""
+    -- field metadata --
+    field_id: '16'
+  child 1, age: int32 not null
+    -- field metadata --
+    field_id: '17'
+  -- field metadata --
+  field_id: '15'"""
     assert repr(actual) == expected
@@ -395,9 +429,9 @@ def test_binary_type_to_pyarrow() -> None:
 def test_struct_type_to_pyarrow(table_schema_simple: Schema) -> None:
     expected = pa.struct(
         [
-            pa.field("foo", pa.string(), nullable=True, metadata={"id": "1"}),
-            pa.field("bar", pa.int32(), nullable=False, metadata={"id": "2"}),
-            pa.field("baz", pa.bool_(), nullable=True, metadata={"id": "3"}),
+            pa.field("foo", pa.string(), nullable=True, metadata={"field_id": "1"}),
+            pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}),
+            pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
         ]
     )
     assert visit(table_schema_simple.as_struct(), _ConvertToArrowSchema()) == expected
@@ -411,7 +445,10 @@ def test_map_type_to_pyarrow() -> None:
     iceberg_map = MapType(
         key_id=1,
         key_type=IntegerType(),
         value_id=2,
         value_type=StringType(),
         value_required=True,
     )
-    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_(pa.int32(), pa.string())
+    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.map_(
+        pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}),
+        pa.field("value", pa.string(), nullable=False, metadata={"field_id": "2"}),
+    )
@@ -420,7 +457,9 @@ def test_list_type_to_pyarrow() -> None:
     iceberg_map = ListType(
         element_id=3,
         element_type=IntegerType(),
         element_required=True,
     )
-    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(pa.int32())
+    assert visit(iceberg_map, _ConvertToArrowSchema()) == pa.list_(
+        pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"})
+    )


 @pytest.fixture
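Worth knowing when reading the two assertions above: pa.Field.__eq__ does not compare metadata by default, which appears to be why the list assertion passes even though its expected field_id does not obviously correspond to the declared element_id. A small check, assuming pyarrow's default equality semantics:

    import pyarrow as pa

    a = pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"})
    b = pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "3"})
    assert a == b                                # metadata is ignored by __eq__
    assert not a.equals(b, check_metadata=True)  # strict comparison sees the difference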
@@ -817,19 +856,28 @@ def test_projection_add_column(file_int: str) -> None:
     for actual, expected in zip(result_table.columns[3], [None, None, None]):
         assert actual.as_py() == expected
-
     assert (
         repr(result_table.schema)
         == """id: int32
-list: list<item: int32>
-  child 0, item: int32
+list: list<element: int32>
+  child 0, element: int32
+    -- field metadata --
+    field_id: '21'
 map: map<int32, string>
   child 0, entries: struct<key: int32 not null, value: string> not null
       child 0, key: int32 not null
+      -- field metadata --
+      field_id: '31'
       child 1, value: string
+      -- field metadata --
+      field_id: '32'
 location: struct<lat: double, lon: double>
   child 0, lat: double
-  child 1, lon: double"""
+    -- field metadata --
+    field_id: '41'
+  child 1, lon: double
+    -- field metadata --
+    field_id: '42'"""
     )
@@ -873,13 +921,16 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None:
     # Everything should be None
     for r in result_table.columns[0]:
         assert r.as_py() is None
-
     assert (
         repr(result_table.schema)
        == """id: map<int32, string>
   child 0, entries: struct<key: int32 not null, value: string> not null
       child 0, key: int32 not null
-      child 1, value: string"""
+      -- field metadata --
+      field_id: '3'
+      child 1, value: string
+      -- field metadata --
+      field_id: '4'"""
     )
@@ -923,7 +974,12 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None:
 def test_projection_filter(schema_int: Schema, file_int: str) -> None:
     result_table = project(schema_int, [file_int], GreaterThan("id", 4))
     assert len(result_table.columns[0]) == 0
-    assert repr(result_table.schema) == "id: int32"
+    assert (
+        repr(result_table.schema)
+        == """id: int32
+  -- field metadata --
+  field_id: '1'"""
+    )


 def test_projection_filter_renamed_column(file_int: str) -> None:
@@ -1099,7 +1155,11 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None:
         repr(result_table.schema)
         == """location: struct<lat: double, long: double>
   child 0, lat: double
-  child 1, long: double"""
+    -- field metadata --
+    field_id: '41'
+  child 1, long: double
+    -- field metadata --
+    field_id: '42'"""
     )
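Before the new visitor tests: the id lookup helpers accept both the key written by pyiceberg itself (b"field_id") and the key pyarrow surfaces for Parquet-level field ids (b"PARQUET:field_id"), and fields without a usable id are simply dropped from the conversion. A minimal sketch of that behavior, assuming the small None guard on field.metadata included above:

    import pyarrow as pa

    from pyiceberg.io.pyarrow import _get_field_id

    assert _get_field_id(pa.field("x", pa.int32(), metadata={b"PARQUET:field_id": b"7"})) == 7
    assert _get_field_id(pa.field("x", pa.int32(), metadata={b"field_id": b"7"})) == 7
    assert _get_field_id(pa.field("x", pa.int32())) is None  # no metadata at all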
diff --git a/python/tests/io/test_pyarrow_visitor.py b/python/tests/io/test_pyarrow_visitor.py
new file mode 100644
index 000000000000..5194d8660e7d
--- /dev/null
+++ b/python/tests/io/test_pyarrow_visitor.py
@@ -0,0 +1,271 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=protected-access,unused-argument,redefined-outer-name
+import re
+
+import pyarrow as pa
+import pytest
+
+from pyiceberg.io.pyarrow import (
+    _ConvertToArrowSchema,
+    _ConvertToIceberg,
+    pyarrow_to_schema,
+    schema_to_pyarrow,
+    visit_pyarrow,
+)
+from pyiceberg.schema import Schema, visit
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    ListType,
+    LongType,
+    MapType,
+    NestedField,
+    StringType,
+    StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
+)
+
+
+def test_pyarrow_binary_to_iceberg() -> None:
+    length = 23
+    pyarrow_type = pa.binary(length)
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == FixedType(length)
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_decimal128_to_iceberg() -> None:
+    precision = 26
+    scale = 20
+    pyarrow_type = pa.decimal128(precision, scale)
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == DecimalType(precision, scale)
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_decimal256_to_iceberg() -> None:
+    precision = 26
+    scale = 20
+    pyarrow_type = pa.decimal256(precision, scale)
+    with pytest.raises(TypeError, match=re.escape("Unsupported type: decimal256(26, 20)")):
+        visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+
+
+def test_pyarrow_boolean_to_iceberg() -> None:
+    pyarrow_type = pa.bool_()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == BooleanType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_int32_to_iceberg() -> None:
+    pyarrow_type = pa.int32()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == IntegerType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_int64_to_iceberg() -> None:
+    pyarrow_type = pa.int64()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == LongType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_float32_to_iceberg() -> None:
+    pyarrow_type = pa.float32()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == FloatType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_float64_to_iceberg() -> None:
+    pyarrow_type = pa.float64()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == DoubleType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_date32_to_iceberg() -> None:
+    pyarrow_type = pa.date32()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == DateType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_date64_to_iceberg() -> None:
+    pyarrow_type = pa.date64()
+    with pytest.raises(TypeError, match=re.escape("Unsupported type: date64")):
+        visit_pyarrow(pyarrow_type, _ConvertToIceberg())
match=re.escape("Unsupported type: time32[ms]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + pyarrow_type = pa.time32("s") + with pytest.raises(TypeError, match=re.escape("Unsupported type: time32[s]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + + +def test_pyarrow_time64_us_to_iceberg() -> None: + pyarrow_type = pa.time64("us") + converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + assert converted_iceberg_type == TimeType() + assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type + + +def test_pyarrow_time64_ns_to_iceberg() -> None: + pyarrow_type = pa.time64("ns") + with pytest.raises(TypeError, match=re.escape("Unsupported type: time64[ns]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + + +def test_pyarrow_timestamp_to_iceberg() -> None: + pyarrow_type = pa.timestamp(unit="us") + converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + assert converted_iceberg_type == TimestampType() + assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type + + +def test_pyarrow_timestamp_invalid_units() -> None: + pyarrow_type = pa.timestamp(unit="ms") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[ms]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + pyarrow_type = pa.timestamp(unit="s") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[s]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + pyarrow_type = pa.timestamp(unit="ns") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[ns]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + + +def test_pyarrow_timestamp_tz_to_iceberg() -> None: + pyarrow_type = pa.timestamp(unit="us", tz="UTC") + pyarrow_type_zero_offset = pa.timestamp(unit="us", tz="+00:00") + converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + converted_iceberg_type_zero_offset = visit_pyarrow(pyarrow_type_zero_offset, _ConvertToIceberg()) + assert converted_iceberg_type == TimestamptzType() + assert converted_iceberg_type_zero_offset == TimestamptzType() + assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type + assert visit(converted_iceberg_type_zero_offset, _ConvertToArrowSchema()) == pyarrow_type + + +def test_pyarrow_timestamp_tz_invalid_units() -> None: + pyarrow_type = pa.timestamp(unit="ms", tz="UTC") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[ms, tz=UTC]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + pyarrow_type = pa.timestamp(unit="s", tz="UTC") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[s, tz=UTC]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + pyarrow_type = pa.timestamp(unit="ns", tz="UTC") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[ns, tz=UTC]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + + +def test_pyarrow_timestamp_tz_invalid_tz() -> None: + pyarrow_type = pa.timestamp(unit="us", tz="US/Pacific") + with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[us, tz=US/Pacific]")): + visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + + +def test_pyarrow_string_to_iceberg() -> None: + pyarrow_type = pa.string() + converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg()) + assert converted_iceberg_type == StringType() + assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type + + +def test_pyarrow_variable_binary_to_iceberg() -> None: 
+
+
+def test_pyarrow_variable_binary_to_iceberg() -> None:
+    pyarrow_type = pa.binary()
+    converted_iceberg_type = visit_pyarrow(pyarrow_type, _ConvertToIceberg())
+    assert converted_iceberg_type == BinaryType()
+    assert visit(converted_iceberg_type, _ConvertToArrowSchema()) == pyarrow_type
+
+
+def test_pyarrow_struct_to_iceberg() -> None:
+    pyarrow_struct = pa.struct(
+        [
+            pa.field("foo", pa.string(), nullable=True, metadata={"field_id": "1", "doc": "foo doc"}),
+            pa.field("bar", pa.int32(), nullable=False, metadata={"field_id": "2"}),
+            pa.field("baz", pa.bool_(), nullable=True, metadata={"field_id": "3"}),
+        ]
+    )
+    expected = StructType(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False, doc="foo doc"),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+    )
+    assert visit_pyarrow(pyarrow_struct, _ConvertToIceberg()) == expected
+
+
+def test_pyarrow_list_to_iceberg() -> None:
+    pyarrow_list = pa.list_(pa.field("element", pa.int32(), nullable=False, metadata={"field_id": "1"}))
+    expected = ListType(
+        element_id=1,
+        element_type=IntegerType(),
+        element_required=True,
+    )
+    assert visit_pyarrow(pyarrow_list, _ConvertToIceberg()) == expected
+
+
+def test_pyarrow_map_to_iceberg() -> None:
+    pyarrow_map = pa.map_(
+        pa.field("key", pa.int32(), nullable=False, metadata={"field_id": "1"}),
+        pa.field("value", pa.string(), nullable=False, metadata={"field_id": "2"}),
+    )
+    expected = MapType(
+        key_id=1,
+        key_type=IntegerType(),
+        value_id=2,
+        value_type=StringType(),
+        value_required=True,
+    )
+    assert visit_pyarrow(pyarrow_map, _ConvertToIceberg()) == expected
+
+
+def test_round_schema_conversion_simple(table_schema_simple: Schema) -> None:
+    actual = str(pyarrow_to_schema(schema_to_pyarrow(table_schema_simple)))
+    expected = """table {
+  1: foo: optional string
+  2: bar: required int
+  3: baz: optional boolean
+}"""
+    assert actual == expected
+
+
+def test_round_schema_conversion_nested(table_schema_nested: Schema) -> None:
+    actual = str(pyarrow_to_schema(schema_to_pyarrow(table_schema_nested)))
+    expected = """table {
+  1: foo: optional string
+  2: bar: required int
+  3: baz: optional boolean
+  4: qux: required list<string>
+  6: quux: required map<string, map<string, int>>
+  11: location: required list<struct<13: latitude: optional float, 14: longitude: optional float>>
+  15: person: optional struct<16: name: optional string, 17: age: required int>
+}"""
+    assert actual == expected
diff --git a/python/tests/test_integration.py b/python/tests/test_integration.py
index 3eb24cd48e62..5577f3d84a2b 100644
--- a/python/tests/test_integration.py
+++ b/python/tests/test_integration.py
@@ -17,11 +17,16 @@
 # pylint:disable=redefined-outer-name
 import math
+from urllib.parse import urlparse

+import pyarrow.parquet as pq
 import pytest
+from pyarrow.fs import S3FileSystem

 from pyiceberg.catalog import Catalog, load_catalog
 from pyiceberg.expressions import IsNaN, NotNaN
+from pyiceberg.io.pyarrow import pyarrow_to_schema
+from pyiceberg.schema import Schema
 from pyiceberg.table import Table
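One note on the paths used in the integration test below: pyarrow's S3FileSystem addresses objects as "bucket/key" with no scheme, so the test strips the s3:// prefix with urlparse. For example, with a made-up path:

    from urllib.parse import urlparse

    uri = urlparse("s3://warehouse/db/table/data/00000-0-abc.parquet")
    assert uri.netloc == "warehouse"
    assert uri.path == "/db/table/data/00000-0-abc.parquet"
    assert f"{uri.netloc}{uri.path}" == "warehouse/db/table/data/00000-0-abc.parquet"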
@@ -135,3 +140,22 @@ def test_ray_all_types(table_test_all_types: Table) -> None:
     pandas_dataframe = table_test_all_types.scan().to_pandas()
     assert ray_dataset.count() == pandas_dataframe.shape[0]
     assert pandas_dataframe.equals(ray_dataset.to_pandas())
+
+
+@pytest.mark.integration
+def test_pyarrow_to_iceberg_all_types(table_test_all_types: Table) -> None:
+    fs = S3FileSystem(
+        **{
+            "endpoint_override": "http://localhost:9000",
+            "access_key": "admin",
+            "secret_key": "password",
+        }
+    )
+    data_file_paths = [task.file.file_path for task in table_test_all_types.scan().plan_files()]
+    for data_file_path in data_file_paths:
+        uri = urlparse(data_file_path)
+        with fs.open_input_file(f"{uri.netloc}{uri.path}") as fin:
+            parquet_schema = pq.read_schema(fin)
+            stored_iceberg_schema = Schema.parse_raw(parquet_schema.metadata.get(b"iceberg.schema"))
+            converted_iceberg_schema = pyarrow_to_schema(parquet_schema)
+            assert converted_iceberg_schema == stored_iceberg_schema
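Taken together, the patch enables conversions like the following end-to-end sketch; the schema here is illustrative, mirroring the round-trip tests above:

    import pyarrow as pa

    from pyiceberg.io.pyarrow import pyarrow_to_schema

    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int64(), nullable=False, metadata={"field_id": "1"}),
            pa.field("name", pa.string(), nullable=True, metadata={"field_id": "2"}),
        ]
    )
    print(pyarrow_to_schema(arrow_schema))
    # table {
    #   1: id: required long
    #   2: name: optional string
    # }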