diff --git a/python/pyiceberg/avro/__init__.py b/python/pyiceberg/avro/__init__.py
index 13a83393a912..d7d8b55ef913 100644
--- a/python/pyiceberg/avro/__init__.py
+++ b/python/pyiceberg/avro/__init__.py
@@ -14,3 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import struct
+
+STRUCT_FLOAT = struct.Struct("<f")  # little-endian float
+STRUCT_DOUBLE = struct.Struct("<d")  # little-endian double
diff --git a/python/pyiceberg/avro/decoder.py b/python/pyiceberg/avro/decoder.py
--- a/python/pyiceberg/avro/decoder.py
+++ b/python/pyiceberg/avro/decoder.py
@@
-STRUCT_SIGNED_SHORT = struct.Struct(">h")  # big-endian signed short
-STRUCT_SIGNED_INT = struct.Struct(">i")  # big-endian signed int
-STRUCT_SIGNED_LONG = struct.Struct(">q")  # big-endian signed long
-
 class BinaryDecoder:
     """Read leaf values."""
diff --git a/python/pyiceberg/avro/encoder.py b/python/pyiceberg/avro/encoder.py
new file mode 100644
index 000000000000..cf6d60123357
--- /dev/null
+++ b/python/pyiceberg/avro/encoder.py
@@ -0,0 +1,175 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import decimal
+import struct
+from datetime import date, datetime, time
+
+from pyiceberg.avro import STRUCT_DOUBLE, STRUCT_FLOAT
+from pyiceberg.io import OutputStream
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_object_to_micros
+
+
+class BinaryEncoder:
+    """Write leaf values."""
+
+    _output_stream: OutputStream
+
+    def __init__(self, output_stream: OutputStream) -> None:
+        self._output_stream = output_stream
+
+    def write(self, b: bytes) -> None:
+        self._output_stream.write(b)
+
+    def write_boolean(self, boolean: bool) -> None:
+        """A boolean is written as a single byte whose value is either 0 (false) or 1 (true).
+
+        Args:
+            boolean: The boolean to write.
+        """
+        self.write(bytearray([bool(boolean)]))
+
+    def write_int(self, integer: int) -> None:
+        """Integer and long values are written using variable-length zig-zag coding."""
+        datum = (integer << 1) ^ (integer >> 63)
+        while (datum & ~0x7F) != 0:
+            self.write(bytearray([(datum & 0x7F) | 0x80]))
+            datum >>= 7
+        self.write(bytearray([datum]))
+
+    def write_float(self, f: float) -> None:
+        """A float is written as 4 bytes."""
+        self.write(STRUCT_FLOAT.pack(f))
+
+    def write_double(self, f: float) -> None:
+        """A double is written as 8 bytes."""
+        self.write(STRUCT_DOUBLE.pack(f))
+
+    def write_decimal_bytes(self, datum: decimal.Decimal) -> None:
+        """
+        A decimal is encoded as length-prefixed bytes: the unscaled value is packed
+        big-endian as a two's-complement integer, preceded by the number of bytes written.
+ """ + sign, digits, _ = datum.as_tuple() + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + bits_req = unscaled_datum.bit_length() + 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + + bytes_req = bits_req // 8 + padding_bits = ~((1 << bits_req) - 1) if sign else 0 + packed_bits = padding_bits | unscaled_datum + + bytes_req += 1 if (bytes_req << 3) < bits_req else 0 + self.write_int(bytes_req) + for index in range(bytes_req - 1, -1, -1): + bits_to_write = packed_bits >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + + def write_decimal_fixed(self, datum: decimal.Decimal, size: int) -> None: + """Decimal in fixed are encoded as size of fixed bytes.""" + sign, digits, _ = datum.as_tuple() + + unscaled_datum = 0 + for digit in digits: + unscaled_datum = (unscaled_datum * 10) + digit + + bits_req = unscaled_datum.bit_length() + 1 + size_in_bits = size * 8 + offset_bits = size_in_bits - bits_req + + mask = 2**size_in_bits - 1 + bit = 1 + for _ in range(bits_req): + mask ^= bit + bit <<= 1 + + if bits_req < 8: + bytes_req = 1 + else: + bytes_req = bits_req // 8 + if bits_req % 8 != 0: + bytes_req += 1 + if sign: + unscaled_datum = (1 << bits_req) - unscaled_datum + unscaled_datum = mask | unscaled_datum + for index in range(size - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + else: + for _ in range(offset_bits // 8): + self.write(b"\x00") + for index in range(bytes_req - 1, -1, -1): + bits_to_write = unscaled_datum >> (8 * index) + self.write(bytearray([bits_to_write & 0xFF])) + + def write_bytes(self, b: bytes) -> None: + """Bytes are encoded as a long followed by that many bytes of data.""" + self.write_int(len(b)) + self.write(struct.pack(f"{len(b)}s", b)) + + def write_bytes_fixed(self, b: bytes) -> None: + """Writes fixed number of bytes.""" + self.write(struct.pack(f"{len(b)}s", b)) + + def write_utf8(self, s: str) -> None: + """A string is encoded as a long followed by that many bytes of UTF-8 encoded character data.""" + self.write_bytes(s.encode("utf-8")) + + def write_date_int(self, d: date) -> None: + """ + Encode python date object as int. + + It stores the number of days from the unix epoch, 1 January 1970 (ISO calendar). + """ + self.write_int(date_to_days(d)) + + def write_time_millis_int(self, dt: time) -> None: + """ + Encode python time object as int. + + It stores the number of milliseconds from midnight, 00:00:00.000 + """ + self.write_int(int(time_object_to_micros(dt) / 1000)) + + def write_time_micros_long(self, dt: time) -> None: + """ + Encode python time object as long. + + It stores the number of microseconds from midnight, 00:00:00.000000 + """ + self.write_int(time_object_to_micros(dt)) + + def write_timestamp_millis_long(self, dt: datetime) -> None: + """ + Encode python datetime object as long. + + It stores the number of milliseconds from midnight of unix epoch, 1 January 1970. + """ + self.write_int(int(datetime_to_micros(dt) / 1000)) + + def write_timestamp_micros_long(self, dt: datetime) -> None: + """ + Encode python datetime object as long. + + It stores the number of microseconds from midnight of unix epoch, 1 January 1970. 
+ """ + self.write_int(datetime_to_micros(dt)) diff --git a/python/pyiceberg/avro/file.py b/python/pyiceberg/avro/file.py index 5c408edd563f..10f7ef7d7de2 100644 --- a/python/pyiceberg/avro/file.py +++ b/python/pyiceberg/avro/file.py @@ -18,7 +18,9 @@ """Avro reader for reading Avro files.""" from __future__ import annotations +import io import json +import os from dataclasses import dataclass from enum import Enum from types import TracebackType @@ -26,6 +28,7 @@ Callable, Dict, Generic, + List, Optional, Type, TypeVar, @@ -33,9 +36,16 @@ from pyiceberg.avro.codecs import KNOWN_CODECS, Codec from pyiceberg.avro.decoder import BinaryDecoder +from pyiceberg.avro.encoder import BinaryEncoder from pyiceberg.avro.reader import Reader -from pyiceberg.avro.resolver import construct_reader, resolve -from pyiceberg.io import InputFile, InputStream +from pyiceberg.avro.resolver import construct_reader, construct_writer, resolve +from pyiceberg.avro.writer import Writer +from pyiceberg.io import ( + InputFile, + InputStream, + OutputFile, + OutputStream, +) from pyiceberg.io.memory import MemoryInputStream from pyiceberg.schema import Schema from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol @@ -147,7 +157,7 @@ def __enter__(self) -> AvroFile[D]: """Generates a reader tree for the payload within an avro file. Returns: - A generator returning the AvroStructs + A generator returning the AvroStructs. """ self.input_stream = self.input_file.open(seekable=False) self.decoder = BinaryDecoder(self.input_stream) @@ -204,3 +214,60 @@ def __next__(self) -> D: def _read_header(self) -> AvroFileHeader: return construct_reader(META_SCHEMA, {-1: AvroFileHeader}).read(self.decoder) + + +class AvroOutputFile(Generic[D]): + output_file: OutputFile + output_stream: OutputStream + schema: Schema + schema_name: str + encoder: BinaryEncoder + sync_bytes: bytes + writer: Writer + + def __init__(self, output_file: OutputFile, schema: Schema, schema_name: str, metadata: Dict[str, str] = EMPTY_DICT) -> None: + self.output_file = output_file + self.schema = schema + self.schema_name = schema_name + self.sync_bytes = os.urandom(SYNC_SIZE) + self.writer = construct_writer(self.schema) + self.metadata = metadata + + def __enter__(self) -> AvroOutputFile[D]: + """ + Opens the file and writes the header. 
+
+        Returns:
+            The file object to write records to.
+        """
+        self.output_stream = self.output_file.create(overwrite=True)
+        self.encoder = BinaryEncoder(self.output_stream)
+
+        self._write_header()
+        self.writer = construct_writer(self.schema)
+
+        return self
+
+    def __exit__(
+        self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
+    ) -> None:
+        """Performs cleanup when exiting the scope of a 'with' statement."""
+        self.output_stream.close()
+
+    def _write_header(self) -> None:
+        json_schema = json.dumps(AvroSchemaConversion().iceberg_to_avro(self.schema, schema_name=self.schema_name))
+        meta = {**self.metadata, _SCHEMA_KEY: json_schema, _CODEC_KEY: "null"}
+        header = AvroFileHeader(magic=MAGIC, meta=meta, sync=self.sync_bytes)
+        construct_writer(META_SCHEMA).write(self.encoder, header)
+
+    def write_block(self, objects: List[D]) -> None:
+        in_memory = io.BytesIO()
+        block_content_encoder = BinaryEncoder(output_stream=in_memory)
+        for obj in objects:
+            self.writer.write(block_content_encoder, obj)
+        block_content = in_memory.getvalue()
+
+        self.encoder.write_int(len(objects))
+        self.encoder.write_int(len(block_content))
+        self.encoder.write(block_content)
+        self.encoder.write_bytes_fixed(self.sync_bytes)
diff --git a/python/pyiceberg/avro/resolver.py b/python/pyiceberg/avro/resolver.py
index 407ce4648df1..51505301ea01 100644
--- a/python/pyiceberg/avro/resolver.py
+++ b/python/pyiceberg/avro/resolver.py
@@ -48,12 +48,34 @@
     TimestamptzReader,
     UUIDReader,
 )
+from pyiceberg.avro.writer import (
+    BinaryWriter,
+    BooleanWriter,
+    DateWriter,
+    DecimalWriter,
+    DoubleWriter,
+    FixedWriter,
+    FloatWriter,
+    IntegerWriter,
+    ListWriter,
+    MapWriter,
+    OptionWriter,
+    StringWriter,
+    StructWriter,
+    TimestamptzWriter,
+    TimestampWriter,
+    TimeWriter,
+    UUIDWriter,
+    Writer,
+)
 from pyiceberg.exceptions import ResolveError
 from pyiceberg.schema import (
     PartnerAccessor,
     PrimitiveWithPartnerVisitor,
     Schema,
+    SchemaVisitorPerPrimitiveType,
     promote,
+    visit,
     visit_with_partner,
 )
 from pyiceberg.typedef import EMPTY_DICT, Record, StructProtocol
@@ -97,6 +119,79 @@
     return resolve(file_schema, file_schema, read_types)
 
 
+def construct_writer(file_schema: Union[Schema, IcebergType]) -> Writer:
+    """Constructs a writer from a file schema.
+
+    Args:
+        file_schema (Schema | IcebergType): The schema of the Avro file.
+
+    Raises:
+        NotImplementedError: If attempting to resolve an unrecognized object type.
+ """ + return visit(file_schema, ConstructWriter()) + + +class ConstructWriter(SchemaVisitorPerPrimitiveType[Writer]): + """Constructs a writer tree from an Iceberg schema.""" + + def schema(self, schema: Schema, struct_result: Writer) -> Writer: + return struct_result + + def struct(self, struct: StructType, field_results: List[Writer]) -> Writer: + return StructWriter(tuple(field_results)) + + def field(self, field: NestedField, field_result: Writer) -> Writer: + return field_result if field.required else OptionWriter(field_result) + + def list(self, list_type: ListType, element_result: Writer) -> Writer: + return ListWriter(element_result) + + def map(self, map_type: MapType, key_result: Writer, value_result: Writer) -> Writer: + return MapWriter(key_result, value_result) + + def visit_fixed(self, fixed_type: FixedType) -> Writer: + return FixedWriter(len(fixed_type)) + + def visit_decimal(self, decimal_type: DecimalType) -> Writer: + return DecimalWriter(decimal_type.precision, decimal_type.scale) + + def visit_boolean(self, boolean_type: BooleanType) -> Writer: + return BooleanWriter() + + def visit_integer(self, integer_type: IntegerType) -> Writer: + return IntegerWriter() + + def visit_long(self, long_type: LongType) -> Writer: + return IntegerWriter() + + def visit_float(self, float_type: FloatType) -> Writer: + return FloatWriter() + + def visit_double(self, double_type: DoubleType) -> Writer: + return DoubleWriter() + + def visit_date(self, date_type: DateType) -> Writer: + return DateWriter() + + def visit_time(self, time_type: TimeType) -> Writer: + return TimeWriter() + + def visit_timestamp(self, timestamp_type: TimestampType) -> Writer: + return TimestampWriter() + + def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> Writer: + return TimestamptzWriter() + + def visit_string(self, string_type: StringType) -> Writer: + return StringWriter() + + def visit_uuid(self, uuid_type: UUIDType) -> Writer: + return UUIDWriter() + + def visit_binary(self, binary_type: BinaryType) -> Writer: + return BinaryWriter() + + def resolve( file_schema: Union[Schema, IcebergType], read_schema: Union[Schema, IcebergType], @@ -106,7 +201,7 @@ def resolve( """Resolves the file and read schema to produce a reader. Args: - file_schema (Schema | IcebergType): The schema of the Avro file + file_schema (Schema | IcebergType): The schema of the Avro file. read_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema. read_types (Dict[int, Callable[..., StructProtocol]]): A dict of types to use for struct data. read_enums (Dict[int, Callable[..., Enum]]): A dict of fields that have to be converted to an enum. 
@@ -249,7 +344,7 @@ def visit_time(self, time_type: TimeType, partner: Optional[IcebergType]) -> Rea
     def visit_timestamp(self, timestamp_type: TimestampType, partner: Optional[IcebergType]) -> Reader:
         return TimestampReader()
 
-    def visit_timestampz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Reader:
+    def visit_timestamptz(self, timestamptz_type: TimestamptzType, partner: Optional[IcebergType]) -> Reader:
         return TimestamptzReader()
 
     def visit_string(self, string_type: StringType, partner: Optional[IcebergType]) -> Reader:
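Before the writer classes below, a quick aside on the wire format they all rely on: `BinaryEncoder.write_int` (added in encoder.py above) uses Avro's zig-zag varint coding. A minimal standalone sketch of the same transform, mirroring the `zigzag_encode` helper that appears in the tests later in this patch:

```python
def zigzag_varint(n: int) -> bytes:
    """Zig-zag map signed -> unsigned (0, -1, 1, -2 -> 0, 1, 2, 3), then 7-bit varint, low bits first."""
    datum = (n << 1) ^ (n >> 63)  # 64-bit zig-zag, as in BinaryEncoder.write_int
    out = bytearray()
    while (datum & ~0x7F) != 0:
        out.append((datum & 0x7F) | 0x80)  # set the continuation bit
        datum >>= 7
    out.append(datum)
    return bytes(out)

assert zigzag_varint(12) == b"\x18"  # matches the b"\x18" expected in test_write_simple_struct
assert zigzag_varint(-1) == b"\x01"
assert zigzag_varint(64) == b"\x80\x01"
```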
diff --git a/python/pyiceberg/avro/writer.py b/python/pyiceberg/avro/writer.py
new file mode 100644
index 000000000000..10a589715dcc
--- /dev/null
+++ b/python/pyiceberg/avro/writer.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Classes for building the Writer tree.
+
+Constructing a writer tree from the schema makes it easy
+to decouple the writer implementation from the schema.
+"""
+from __future__ import annotations
+
+from abc import abstractmethod
+from dataclasses import dataclass
+from dataclasses import field as dataclassfield
+from datetime import datetime, time
+from typing import (
+    Any,
+    Dict,
+    List,
+    Tuple,
+)
+from uuid import UUID
+
+from pyiceberg.avro.encoder import BinaryEncoder
+from pyiceberg.typedef import Record
+from pyiceberg.utils.singleton import Singleton
+
+
+class Writer(Singleton):
+    @abstractmethod
+    def write(self, encoder: BinaryEncoder, val: Any) -> Any:
+        ...
+
+    def __repr__(self) -> str:
+        """Returns string representation of this object."""
+        return f"{self.__class__.__name__}()"
+
+
+class NoneWriter(Writer):
+    def write(self, _: BinaryEncoder, __: Any) -> None:
+        pass
+
+
+class BooleanWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: bool) -> None:
+        encoder.write_boolean(val)
+
+
+class IntegerWriter(Writer):
+    """Longs and ints are encoded the same way, and there is no long in Python."""
+
+    def write(self, encoder: BinaryEncoder, val: int) -> None:
+        encoder.write_int(val)
+
+
+class FloatWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: float) -> None:
+        encoder.write_float(val)
+
+
+class DoubleWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: float) -> None:
+        encoder.write_double(val)
+
+
+class DateWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: Any) -> None:
+        encoder.write_date_int(val)
+
+
+class TimeWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: time) -> None:
+        encoder.write_time_micros_long(val)
+
+
+class TimestampWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: datetime) -> None:
+        encoder.write_timestamp_micros_long(val)
+
+
+class TimestamptzWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: datetime) -> None:
+        encoder.write_timestamp_micros_long(val)
+
+
+class StringWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: Any) -> None:
+        encoder.write_utf8(val)
+
+
+class UUIDWriter(Writer):
+    def write(self, encoder: BinaryEncoder, val: UUID) -> None:
+        uuid_bytes = val.bytes
+
+        if len(uuid_bytes) != 16:
+            raise ValueError(f"Expected UUID to be 16 bytes, got: {len(uuid_bytes)}")
+
+        encoder.write_bytes_fixed(uuid_bytes)
+
+
+@dataclass(frozen=True)
+class FixedWriter(Writer):
+    _len: int = dataclassfield()
+
+    def write(self, encoder: BinaryEncoder, val: bytes) -> None:
+        encoder.write(val)
+
+    def __len__(self) -> int:
+        """Returns the length of this object."""
+        return self._len
+
+    def __repr__(self) -> str:
+        """Returns string representation of this object."""
+        return f"FixedWriter({self._len})"
+
+
+class BinaryWriter(Writer):
+    """Variable byte length writer."""
+
+    def write(self, encoder: BinaryEncoder, val: Any) -> None:
+        encoder.write_bytes(val)
+
+
+@dataclass(frozen=True)
+class DecimalWriter(Writer):
+    precision: int = dataclassfield()
+    scale: int = dataclassfield()
+
+    def write(self, encoder: BinaryEncoder, val: Any) -> None:
+        return encoder.write_decimal_bytes(val)
+
+    def __repr__(self) -> str:
+        """Returns string representation of this object."""
+        return f"DecimalWriter({self.precision}, {self.scale})"
+
+
+@dataclass(frozen=True)
+class OptionWriter(Writer):
+    option: Writer = dataclassfield()
+
+    def write(self, encoder: BinaryEncoder, val: Any) -> None:
+        if val is not None:
+            encoder.write_int(1)
+            self.option.write(encoder, val)
+        else:
+            encoder.write_int(0)
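A short sketch of the bytes `OptionWriter` produces: Iceberg's optional fields become Avro `["null", T]` unions, so the writer emits the zig-zag-encoded branch index (0 for null, 1 for a value) before the value itself. Assuming the encoder and writers above:

```python
import io

from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.writer import IntegerWriter, OptionWriter

writer = OptionWriter(IntegerWriter())

out = io.BytesIO()
writer.write(BinaryEncoder(out), 5)
assert out.getvalue() == b"\x02\x0a"  # branch 1 -> zigzag(1) = 0x02, then zigzag(5) = 0x0a

out = io.BytesIO()
writer.write(BinaryEncoder(out), None)
assert out.getvalue() == b"\x00"  # branch 0 -> null, nothing follows
```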
+
+
+@dataclass(frozen=True)
+class StructWriter(Writer):
+    field_writers: Tuple[Writer, ...] = dataclassfield()
+
+    def write(self, encoder: BinaryEncoder, val: Record) -> None:
+        for writer, value in zip(self.field_writers, val.record_fields()):
+            writer.write(encoder, value)
+
+    def __eq__(self, other: Any) -> bool:
+        """Implements the equality operator for this object."""
+        return self.field_writers == other.field_writers if isinstance(other, StructWriter) else False
+
+    def __repr__(self) -> str:
+        """Returns string representation of this object."""
+        return f"StructWriter({','.join(repr(field) for field in self.field_writers)})"
+
+    def __hash__(self) -> int:
+        """Returns the hash of the writer as hash of this object."""
+        return hash(self.field_writers)
+
+
+@dataclass(frozen=True)
+class ListWriter(Writer):
+    element_writer: Writer
+
+    def write(self, encoder: BinaryEncoder, val: List[Any]) -> None:
+        encoder.write_int(len(val))
+        for v in val:
+            self.element_writer.write(encoder, v)
+        if len(val) > 0:
+            encoder.write_int(0)
+
+
+@dataclass(frozen=True)
+class MapWriter(Writer):
+    key_writer: Writer
+    value_writer: Writer
+
+    def write(self, encoder: BinaryEncoder, val: Dict[Any, Any]) -> None:
+        encoder.write_int(len(val))
+        for k, v in val.items():
+            self.key_writer.write(encoder, k)
+            self.value_writer.write(encoder, v)
+        if len(val) > 0:
+            encoder.write_int(0)
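To illustrate the framing `ListWriter` and `MapWriter` use: Avro arrays and maps are written as a sequence of blocks, each prefixed with an item count and terminated by a zero-count block. A sketch of the single-block output that `write` produces, using the writers above:

```python
import io

from pyiceberg.avro.encoder import BinaryEncoder
from pyiceberg.avro.writer import IntegerWriter, ListWriter

out = io.BytesIO()
ListWriter(IntegerWriter()).write(BinaryEncoder(out), [1, 2, 3])

# zigzag(3) = 0x06 count, items zigzag(1), zigzag(2), zigzag(3), then the 0x00 end-of-blocks marker
assert out.getvalue() == b"\x06\x02\x04\x06\x00"
```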
visit_timestamptz(self, timestamptz_type: TimestamptzType) -> T:
         """Visit a TimestamptzType."""
 
     @abstractmethod
@@ -1367,12 +1367,3 @@
         raise ResolveError(f"Cannot reduce precision from {file_type} to {read_type}")
     else:
         raise ResolveError(f"Cannot promote an decimal to {read_type}")
-
-
-@promote.register(FixedType)
-def _(file_type: FixedType, read_type: IcebergType) -> IcebergType:
-    if isinstance(read_type, UUIDType) and len(file_type) == 16:
-        # Since pyarrow reads parquet UUID as fixed 16-byte binary, the promotion is needed to ensure read compatibility
-        return read_type
-    else:
-        raise ResolveError(f"Cannot promote {file_type} to {read_type}")
diff --git a/python/pyiceberg/typedef.py b/python/pyiceberg/typedef.py
index 184b788c880c..2e4d1b938fd6 100644
--- a/python/pyiceberg/typedef.py
+++ b/python/pyiceberg/typedef.py
@@ -24,6 +24,7 @@
     Any,
     Callable,
     Dict,
+    List,
     Optional,
     Protocol,
     Set,
@@ -161,3 +162,6 @@
     def __eq__(self, other: Any) -> bool:
 
     def __repr__(self) -> str:
         """Returns the string representation of the Record class."""
         return f"{self.__class__.__name__}[{', '.join(f'{key}={repr(value)}' for key, value in self.__dict__.items() if not key.startswith('_'))}]"
+
+    def record_fields(self) -> List[Any]:
+        return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name.values()]
diff --git a/python/pyiceberg/utils/datetime.py b/python/pyiceberg/utils/datetime.py
index 918a22ddbfb8..fd9f6c677609 100644
--- a/python/pyiceberg/utils/datetime.py
+++ b/python/pyiceberg/utils/datetime.py
@@ -67,6 +67,11 @@
     return (((t.hour * 60 + t.minute) * 60) + t.second) * 1_000_000 + t.microsecond
 
 
+def time_object_to_micros(t: time) -> int:
+    """Converts a datetime.time object to microseconds from midnight."""
+    return int(t.hour * 60 * 60 * 1e6 + t.minute * 60 * 1e6 + t.second * 1e6 + t.microsecond)
+
+
 def datetime_to_micros(dt: datetime) -> int:
     """Converts a datetime to microseconds from 1970-01-01T00:00:00.000000."""
     if dt.tzinfo:
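`record_fields` is what `StructWriter` iterates over: it returns the record's field values in schema position order, with `None` for unset fields. A small sketch, using a hypothetical `Person` record in the same style as the `MyStruct` subclasses in the tests below:

```python
from pyiceberg.typedef import Record


class Person(Record):  # hypothetical record with two positional fields
    name: str
    age: int


person = Person(name="Alice", age=31)

# Values come back in position order, ready to be zipped with StructWriter.field_writers.
assert person.record_fields() == ["Alice", 31]
```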
diff --git a/python/pyiceberg/utils/schema_conversion.py b/python/pyiceberg/utils/schema_conversion.py
index 0a1aaabc02ee..4f46668866da 100644
--- a/python/pyiceberg/utils/schema_conversion.py
+++ b/python/pyiceberg/utils/schema_conversion.py
@@ -20,11 +20,12 @@
     Any,
     Dict,
     List,
+    Optional,
     Tuple,
     Union,
 )
 
-from pyiceberg.schema import Schema
+from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType, visit
 from pyiceberg.types import (
     BinaryType,
     BooleanType,
@@ -43,6 +44,7 @@
     StringType,
     StructType,
     TimestampType,
+    TimestamptzType,
     TimeType,
     UUIDType,
 )
@@ -69,6 +71,8 @@
     ("uuid", "fixed"): UUIDType(),
 }
 
+AvroType = Union[str, Any]
+
 
 class AvroSchemaConversion:
     def avro_to_iceberg(self, avro_schema: Dict[str, Any]) -> Schema:
@@ -116,6 +120,10 @@
         """
         return Schema(*[self._convert_field(field) for field in avro_schema["fields"]], schema_id=1)
 
+    def iceberg_to_avro(self, schema: Schema, schema_name: Optional[str] = None) -> AvroType:
+        """Converts an Iceberg schema into an Avro dictionary that can be serialized to JSON."""
+        return visit(schema, ConvertSchemaToAvro(schema_name))
+
     def _resolve_union(
         self, type_union: Union[Dict[str, str], List[Union[str, Dict[str, str]]], str]
     ) -> Tuple[Union[str, Dict[str, Any]], bool]:
@@ -468,3 +476,129 @@
         An Iceberg equivalent fixed type.
         """
         return FixedType(length=avro_type["size"])
+
+
+class ConvertSchemaToAvro(SchemaVisitorPerPrimitiveType[AvroType]):
+    """Converts an Iceberg schema to an Avro schema."""
+
+    schema_name: Optional[str]
+    last_list_field_id: int
+    last_map_key_field_id: int
+    last_map_value_field_id: int
+
+    def __init__(self, schema_name: Optional[str]) -> None:
+        """Converts an Iceberg schema to an Avro schema.
+
+        Args:
+            schema_name: The name of the root record.
+        """
+        self.schema_name = schema_name
+
+    def schema(self, schema: Schema, struct_result: AvroType) -> AvroType:
+        if isinstance(struct_result, dict) and self.schema_name is not None:
+            struct_result["name"] = self.schema_name
+        return struct_result
+
+    def before_list_element(self, element: NestedField) -> None:
+        self.last_list_field_id = element.field_id
+
+    def before_map_key(self, key: NestedField) -> None:
+        self.last_map_key_field_id = key.field_id
+
+    def before_map_value(self, value: NestedField) -> None:
+        self.last_map_value_field_id = value.field_id
+
+    def struct(self, struct: StructType, field_results: List[AvroType]) -> AvroType:
+        return {"type": "record", "fields": field_results}
+
+    def field(self, field: NestedField, field_result: AvroType) -> AvroType:
+        # Sets the schema name
+        if isinstance(field_result, dict) and field_result.get("type") == "record":
+            field_result["name"] = f"r{field.field_id}"
+
+        result = {
+            "name": field.name,
+            "field-id": field.field_id,
+            "type": field_result if field.required else ["null", field_result],
+        }
+
+        if field.optional:
+            result["default"] = None
+
+        if field.doc is not None:
+            result["doc"] = field.doc
+
+        return result
+
+    def list(self, list_type: ListType, element_result: AvroType) -> AvroType:
+        # Sets the schema name in case of a record
+        if isinstance(element_result, dict) and element_result.get("type") == "record":
+            element_result["name"] = f"r{self.last_list_field_id}"
+        return {"type": "array", "element-id": self.last_list_field_id, "items": element_result}
+
+    def map(self, map_type: MapType, key_result: AvroType, value_result: AvroType) -> AvroType:
+        if isinstance(map_type.key_type, StringType):
+            # Avro maps only support string keys
+            return {
+                "type": "map",
+                "values": value_result,
+            }
+        else:
+            # For non-string keys, creates the binary-compatible logical map:
+            # an array of key/value records annotated with logicalType map
+            return {
+                "type": "array",
+                "items": {
+                    "type": "record",
+                    "name": f"k{self.last_map_key_field_id}_v{self.last_map_value_field_id}",
+                    "fields": [
+                        {"name": "key", "type": key_result, "field-id": self.last_map_key_field_id},
+                        {"name": "value", "type": value_result, "field-id": self.last_map_value_field_id},
+                    ],
+                },
+                "logicalType": "map",
+            }
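For instance, an Iceberg `map<int, long>` (the shape used for `column_sizes` and the other metrics maps in manifests) cannot be an Avro map, so the `map` method above emits the array-of-records form. A sketch of the resulting dict, assuming hypothetical key/value field ids 117 and 118:

```python
# Shape produced by ConvertSchemaToAvro.map for map<int, long> with key id 117, value id 118
{
    "type": "array",
    "items": {
        "type": "record",
        "name": "k117_v118",
        "fields": [
            {"name": "key", "type": "int", "field-id": 117},
            {"name": "value", "type": "long", "field-id": 118},
        ],
    },
    "logicalType": "map",
}
```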
+
+    def visit_fixed(self, fixed_type: FixedType) -> AvroType:
+        return {"type": "fixed", "size": len(fixed_type)}
+
+    def visit_decimal(self, decimal_type: DecimalType) -> AvroType:
+        return {"type": "bytes", "logicalType": "decimal", "precision": decimal_type.precision, "scale": decimal_type.scale}
+
+    def visit_boolean(self, boolean_type: BooleanType) -> AvroType:
+        return "boolean"
+
+    def visit_integer(self, integer_type: IntegerType) -> AvroType:
+        return "int"
+
+    def visit_long(self, long_type: LongType) -> AvroType:
+        return "long"
+
+    def visit_float(self, float_type: FloatType) -> AvroType:
+        return "float"
+
+    def visit_double(self, double_type: DoubleType) -> AvroType:
+        return "double"
+
+    def visit_date(self, date_type: DateType) -> AvroType:
+        return {"type": "int", "logicalType": "date"}
+
+    def visit_time(self, time_type: TimeType) -> AvroType:
+        return {"type": "long", "logicalType": "time-micros"}
+
+    def visit_timestamp(self, timestamp_type: TimestampType) -> AvroType:
+        # Iceberg only supports microsecond precision
+        return {"type": "long", "logicalType": "timestamp-micros"}
+
+    def visit_timestamptz(self, timestamptz_type: TimestamptzType) -> AvroType:
+        # Iceberg only supports microsecond precision
+        return {"type": "long", "logicalType": "timestamp-micros"}
+
+    def visit_string(self, string_type: StringType) -> AvroType:
+        return "string"
+
+    def visit_uuid(self, uuid_type: UUIDType) -> AvroType:
+        return {"type": "string", "logicalType": "uuid"}
+
+    def visit_binary(self, binary_type: BinaryType) -> AvroType:
+        return "bytes"
diff --git a/python/tests/avro/test_encoder.py b/python/tests/avro/test_encoder.py
new file mode 100644
index 000000000000..4646e65e6e61
--- /dev/null
+++ b/python/tests/avro/test_encoder.py
@@ -0,0 +1,207 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+import io
+import struct
+from decimal import Decimal
+
+from pyiceberg.avro.encoder import BinaryEncoder
+
+
+def test_write() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = b"\x12\x34\x56"
+
+    encoder.write(_input)
+
+    assert output.getbuffer() == _input
+
+
+def test_write_boolean() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    encoder.write_boolean(True)
+    encoder.write_boolean(False)
+
+    assert output.getbuffer() == struct.pack("??", True, False)
+
+
+def test_write_int() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _1byte_input = 2
+    _2byte_input = 7466
+    _3byte_input = 523490
+    _4byte_input = 86561570
+    _5byte_input = 2510416930
+    _6byte_input = 734929016866
+    _7byte_input = 135081528772642
+    _8byte_input = 35124861473277986
+
+    encoder.write_int(_1byte_input)
+    encoder.write_int(_2byte_input)
+    encoder.write_int(_3byte_input)
+    encoder.write_int(_4byte_input)
+    encoder.write_int(_5byte_input)
+    encoder.write_int(_6byte_input)
+    encoder.write_int(_7byte_input)
+    encoder.write_int(_8byte_input)
+
+    buffer = output.getbuffer()
+
+    assert buffer[0:1] == b"\x04"
+    assert buffer[1:3] == b"\xd4\x74"
+    assert buffer[3:6] == b"\xc4\xf3\x3f"
+    assert buffer[6:10] == b"\xc4\xcc\xc6\x52"
+    assert buffer[10:15] == b"\xc4\xb0\x8f\xda\x12"
+    assert buffer[15:21] == b"\xc4\xe0\xf6\xd2\xe3\x2a"
+    assert buffer[21:28] == b"\xc4\xa0\xce\xe8\xe3\xb6\x3d"
+    assert buffer[28:36] == b"\xc4\xa0\xb2\xae\x83\xf8\xe4\x7c"
+
+
+def test_write_float() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = 3.14159265359
+
+    encoder.write_float(_input)
+
+    assert output.getbuffer() == struct.pack("<f", _input)
+
+
+def test_write_double() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = 3.14159265359
+
+    encoder.write_double(_input)
+
+    assert output.getbuffer() == struct.pack("<d", _input)
+
+
+def test_write_decimal_bytes() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = Decimal("3.14159265359")
+
+    encoder.write_decimal_bytes(_input)
+
+    assert output.getbuffer() == b"\x0a\x49\x25\x59\xf6\x4f"
+
+
+def test_write_decimal_fixed() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = Decimal("3.14159265359")
+
+    encoder.write_decimal_fixed(_input, 8)
+
+    assert output.getbuffer() == b"\x00\x00\x00\x49\x25\x59\xf6\x4f"
+
+
+def test_write_bytes() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = b"\x12\x34\x56"
+
+    encoder.write_bytes(_input)
+
+    assert output.getbuffer() == b"".join([b"\x06", _input])
+
+
+def test_write_bytes_fixed() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = b"\x12\x34\x56"
+
+    encoder.write_bytes_fixed(_input)
+
+    assert output.getbuffer() == _input
+
+
+def test_write_utf8() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = "That, my liege, is how we know the Earth to be banana-shaped."
+    bin_input = _input.encode()
+    encoder.write_utf8(_input)
+
+    assert output.getbuffer() == b"".join([b"\x7a", bin_input])
+
+
+def test_write_date_int() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = datetime.date(1970, 1, 2)
+    encoder.write_date_int(_input)
+
+    assert output.getbuffer() == b"\x02"
+
+
+def test_write_time_millis_int() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = datetime.time(1, 2, 3, 456000)
+    encoder.write_time_millis_int(_input)
+
+    assert output.getbuffer() == b"\x80\xc3\xc6\x03"
+
+
+def test_write_time_micros_long() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = datetime.time(1, 2, 3, 456000)
+
+    encoder.write_time_micros_long(_input)
+
+    assert output.getbuffer() == b"\x80\xb8\xfb\xde\x1b"
+
+
+def test_write_timestamp_millis_long() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = datetime.datetime(2023, 1, 1, 1, 2, 3)
+    encoder.write_timestamp_millis_long(_input)
+
+    assert output.getbuffer() == b"\xf0\xdb\xcc\xad\xad\x61"
+
+
+def test_write_timestamp_micros_long() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    _input = datetime.datetime(2023, 1, 1, 1, 2, 3)
+    encoder.write_timestamp_micros_long(_input)
+
+    assert output.getbuffer() == b"\x80\xe3\xad\x9f\xac\xca\xf8\x05"
diff --git a/python/tests/avro/test_file.py b/python/tests/avro/test_file.py
index cdb973a39541..53d0216ab07c 100644
--- a/python/tests/avro/test_file.py
+++ b/python/tests/avro/test_file.py
@@ -14,10 +14,27 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from enum import Enum
+from tempfile import TemporaryDirectory
+from typing import Any
+
 import pytest
+from fastavro import reader, writer
 
+import pyiceberg.avro.file as avro
 from pyiceberg.avro.codecs import DeflateCodec
 from pyiceberg.avro.file import META_SCHEMA, AvroFileHeader
+from pyiceberg.io.pyarrow import PyArrowFileIO
+from pyiceberg.manifest import (
+    MANIFEST_ENTRY_SCHEMA,
+    DataFile,
+    DataFileContent,
+    FileFormat,
+    ManifestEntry,
+    ManifestEntryStatus,
+)
+from pyiceberg.typedef import Record
+from pyiceberg.utils.schema_conversion import AvroSchemaConversion
 
 
 def get_deflate_compressor() -> None:
@@ -58,3 +75,120 @@
         header.get_schema()
 
     assert "No schema found in Avro file headers" in str(exc_info.value)
+
+
+# helper function to serialize our objects to dicts to enable
+# direct comparison with the dicts returned by fastavro
+def todict(obj: Any) -> Any:
+    if isinstance(obj, dict):
+        data = []
+        for k, v in obj.items():
+            data.append({"key": k, "value": v})
+        return data
+    elif isinstance(obj, Enum):
+        return obj.value
+    elif hasattr(obj, "__iter__") and not isinstance(obj, str) and not isinstance(obj, bytes):
+        return [todict(v) for v in obj]
+    elif hasattr(obj, "__dict__"):
+        return {key: todict(value) for key, value in obj.__dict__.items() if not callable(value) and not key.startswith("_")}
+    else:
+        return obj
+
+
+def test_write_manifest_entry_with_iceberg_read_with_fastavro() -> None:
+    data_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path="s3://some-path/some-file.parquet",
+        file_format=FileFormat.PARQUET,
+        partition=Record(),
+        record_count=131327,
+        file_size_in_bytes=220669226,
+        column_sizes={1: 220661854},
+        value_counts={1: 131327},
+        null_value_counts={1: 0},
+        nan_value_counts={},
+        lower_bounds={1: b"aaaaaaaaaaaaaaaa"},
+        upper_bounds={1: b"zzzzzzzzzzzzzzzz"},
+        key_metadata=b"\xde\xad\xbe\xef",
+        split_offsets=[4, 133697593],
+        equality_ids=[],
+        sort_order_id=4,
+        spec_id=3,
+    )
+    entry = ManifestEntry(
+        status=ManifestEntryStatus.ADDED,
+        snapshot_id=8638475580105682862,
+        data_sequence_number=0,
+        file_sequence_number=0,
+        data_file=data_file,
+    )
+
+    additional_metadata = {"foo": "bar"}
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_avro_file = tmpdir + "/manifest_entry.avro"
+
+        with avro.AvroOutputFile[ManifestEntry](
+            PyArrowFileIO().new_output(tmp_avro_file), MANIFEST_ENTRY_SCHEMA, "manifest_entry", additional_metadata
+        ) as out:
+            out.write_block([entry])
+
+        with open(tmp_avro_file, "rb") as fo:
+            r = reader(fo=fo)
+
+            for k, v in additional_metadata.items():
+                assert k in r.metadata
+                assert v == r.metadata[k]
+
+            it = iter(r)
+
+            fa_entry = next(it)
+
+            assert todict(entry) == fa_entry
+
+
+def test_write_manifest_entry_with_fastavro_read_with_iceberg() -> None:
+    data_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path="s3://some-path/some-file.parquet",
+        file_format=FileFormat.PARQUET,
+        partition=Record(),
+        record_count=131327,
+        file_size_in_bytes=220669226,
+        column_sizes={1: 220661854},
+        value_counts={1: 131327},
+        null_value_counts={1: 0},
+        nan_value_counts={},
+        lower_bounds={1: b"aaaaaaaaaaaaaaaa"},
+        upper_bounds={1: b"zzzzzzzzzzzzzzzz"},
+        key_metadata=b"\xde\xad\xbe\xef",
+        split_offsets=[4, 133697593],
+        equality_ids=[],
+        sort_order_id=4,
+        spec_id=3,
+    )
+    entry = ManifestEntry(
+        status=ManifestEntryStatus.ADDED,
+        snapshot_id=8638475580105682862,
+        data_sequence_number=0,
+        file_sequence_number=0,
+        data_file=data_file,
+    )
+
+    with TemporaryDirectory() as tmpdir:
+        tmp_avro_file = tmpdir + "/manifest_entry.avro"
+
+        schema = AvroSchemaConversion().iceberg_to_avro(MANIFEST_ENTRY_SCHEMA, schema_name="manifest_entry")
+
+        with open(tmp_avro_file, "wb") as out:
+            writer(out, schema, [todict(entry)])
+
+        with avro.AvroFile[ManifestEntry](
+            PyArrowFileIO().new_input(tmp_avro_file),
+            MANIFEST_ENTRY_SCHEMA,
+            {-1: ManifestEntry, 2: DataFile},
+        ) as avro_reader:
+            it = iter(avro_reader)
+            avro_entry = next(it)
+
+            assert entry == avro_entry
diff --git a/python/tests/avro/test_writer.py b/python/tests/avro/test_writer.py
new file mode 100644
index 000000000000..c517a0cd1c4d
--- /dev/null
+++ b/python/tests/avro/test_writer.py
@@ -0,0 +1,220 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=protected-access
+
+import io
+import struct
+from typing import Dict, List
+
+import pytest
+
+from pyiceberg.avro.encoder import BinaryEncoder
+from pyiceberg.avro.resolver import construct_writer
+from pyiceberg.avro.writer import (
+    BinaryWriter,
+    BooleanWriter,
+    DateWriter,
+    DecimalWriter,
+    DoubleWriter,
+    FixedWriter,
+    FloatWriter,
+    IntegerWriter,
+    StringWriter,
+    TimestamptzWriter,
+    TimestampWriter,
+    TimeWriter,
+    UUIDWriter,
+)
+from pyiceberg.typedef import Record
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    ListType,
+    LongType,
+    MapType,
+    NestedField,
+    PrimitiveType,
+    StringType,
+    StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
+    UUIDType,
+)
+
+
+def zigzag_encode(datum: int) -> bytes:
+    result = []
+    datum = (datum << 1) ^ (datum >> 63)
+    while (datum & ~0x7F) != 0:
+        result.append(struct.pack("B", (datum & 0x7F) | 0x80))
+        datum >>= 7
+    result.append(struct.pack("B", datum))
+    return b"".join(result)
+
+
+def test_fixed_writer() -> None:
+    assert construct_writer(FixedType(22)) == FixedWriter(22)
+
+
+def test_decimal_writer() -> None:
+    assert construct_writer(DecimalType(19, 25)) == DecimalWriter(19, 25)
+
+
+def test_boolean_writer() -> None:
+    assert construct_writer(BooleanType()) == BooleanWriter()
+
+
+def test_integer_writer() -> None:
+    assert construct_writer(IntegerType()) == IntegerWriter()
+
+
+def test_long_writer() -> None:
+    assert construct_writer(LongType()) == IntegerWriter()
+
+
+def test_float_writer() -> None:
+    assert construct_writer(FloatType()) == FloatWriter()
+
+
+def test_double_writer() -> None:
+    assert construct_writer(DoubleType()) == DoubleWriter()
+
+
+def test_date_writer() -> None:
+    assert construct_writer(DateType()) == DateWriter()
+
+
+def test_time_writer() -> None:
+    assert construct_writer(TimeType()) == TimeWriter()
+
+
+def test_timestamp_writer() -> None:
+    assert construct_writer(TimestampType()) == TimestampWriter()
+
+
+def 
test_timestamptz_writer() -> None:
+    assert construct_writer(TimestamptzType()) == TimestamptzWriter()
+
+
+def test_string_writer() -> None:
+    assert construct_writer(StringType()) == StringWriter()
+
+
+def test_binary_writer() -> None:
+    assert construct_writer(BinaryType()) == BinaryWriter()
+
+
+def test_unknown_type() -> None:
+    class UnknownType(PrimitiveType):
+        __root__ = "UnknownType"
+
+    with pytest.raises(ValueError) as exc_info:
+        construct_writer(UnknownType())
+
+    assert "Unknown type:" in str(exc_info.value)
+
+
+def test_uuid_writer() -> None:
+    assert construct_writer(UUIDType()) == UUIDWriter()
+
+
+def test_write_simple_struct() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    schema = StructType(
+        NestedField(1, "id", IntegerType(), required=True), NestedField(2, "property", StringType(), required=True)
+    )
+
+    class MyStruct(Record):
+        id: int
+        property: str
+
+    my_struct = MyStruct(id=12, property="awesome")
+
+    enc_str = b"awesome"
+
+    construct_writer(schema).write(encoder, my_struct)
+
+    assert output.getbuffer() == b"".join([b"\x18", zigzag_encode(len(enc_str)), enc_str])
+
+
+def test_write_struct_with_dict() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    schema = StructType(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(2, "properties", MapType(3, IntegerType(), 4, IntegerType()), required=True),
+    )
+
+    class MyStruct(Record):
+        id: int
+        properties: Dict[int, int]
+
+    my_struct = MyStruct(id=12, properties={1: 2, 3: 4})
+
+    construct_writer(schema).write(encoder, my_struct)
+
+    assert output.getbuffer() == b"".join(
+        [
+            b"\x18",
+            zigzag_encode(len(my_struct.properties)),
+            zigzag_encode(1),
+            zigzag_encode(2),
+            zigzag_encode(3),
+            zigzag_encode(4),
+            b"\x00",
+        ]
+    )
+
+
+def test_write_struct_with_list() -> None:
+    output = io.BytesIO()
+    encoder = BinaryEncoder(output)
+
+    schema = StructType(
+        NestedField(1, "id", IntegerType(), required=True),
+        NestedField(2, "properties", ListType(3, IntegerType()), required=True),
+    )
+
+    class MyStruct(Record):
+        id: int
+        properties: List[int]
+
+    my_struct = MyStruct(id=12, properties=[1, 2, 3, 4])
+
+    construct_writer(schema).write(encoder, my_struct)
+
+    assert output.getbuffer() == b"".join(
+        [
+            b"\x18",
+            zigzag_encode(len(my_struct.properties)),
+            zigzag_encode(1),
+            zigzag_encode(2),
+            zigzag_encode(3),
+            zigzag_encode(4),
+            b"\x00",
+        ]
+    )
diff --git a/python/tests/utils/test_schema_conversion.py b/python/tests/utils/test_schema_conversion.py
index 6a8c5a28c75a..2c42c445e432 100644
--- a/python/tests/utils/test_schema_conversion.py
+++ b/python/tests/utils/test_schema_conversion.py
@@ -37,7 +37,7 @@
 from pyiceberg.utils.schema_conversion import AvroSchemaConversion
 
 
-def test_iceberg_to_avro(avro_schema_manifest_file_v1: Dict[str, Any]) -> None:
+def test_avro_to_iceberg(avro_schema_manifest_file_v1: Dict[str, Any]) -> None:
     iceberg_schema = AvroSchemaConversion().avro_to_iceberg(avro_schema_manifest_file_v1)
     expected_iceberg_schema = Schema(
         NestedField(
@@ -354,3 +354,17 @@
         AvroSchemaConversion()._convert_logical_map_type(avro_type)
 
     assert "Invalid key-value pair schema:" in str(exc_info.value)
+
+
+def test_iceberg_to_avro_manifest_list(avro_schema_manifest_file_v1: Dict[str, Any]) -> None:
+    """Round trip the manifest list schema."""
+    iceberg_schema = AvroSchemaConversion().avro_to_iceberg(avro_schema_manifest_file_v1)
+    avro_result = AvroSchemaConversion().iceberg_to_avro(iceberg_schema, schema_name="manifest_file")
+
+    assert avro_schema_manifest_file_v1 == avro_result
+
+
+def test_iceberg_to_avro_manifest(avro_schema_manifest_entry: Dict[str, Any]) -> None:
+    """Round trip the manifest entry schema."""
+    iceberg_schema = AvroSchemaConversion().avro_to_iceberg(avro_schema_manifest_entry)
+    avro_result = AvroSchemaConversion().iceberg_to_avro(iceberg_schema, schema_name="manifest_entry")
+
+    assert avro_schema_manifest_entry == avro_result
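For a concrete feel of the conversion these round-trip tests exercise, here is a small sketch (field names hypothetical) of what `iceberg_to_avro` produces for a two-field schema, following the `ConvertSchemaToAvro` rules above:

```python
from pyiceberg.schema import Schema
from pyiceberg.types import LongType, NestedField, StringType
from pyiceberg.utils.schema_conversion import AvroSchemaConversion

schema = Schema(
    NestedField(field_id=1, name="id", field_type=LongType(), required=True),
    NestedField(field_id=2, name="data", field_type=StringType(), required=False),
    schema_id=1,
)

assert AvroSchemaConversion().iceberg_to_avro(schema, schema_name="example") == {
    "type": "record",
    "name": "example",
    "fields": [
        {"name": "id", "field-id": 1, "type": "long"},
        # Optional fields become a ["null", T] union with a null default.
        {"name": "data", "field-id": 2, "type": ["null", "string"], "default": None},
    ],
}
```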