diff --git a/python/src/iceberg/avro/decoder.py b/python/src/iceberg/avro/decoder.py index 586aabf97814..702ad9d4daf5 100644 --- a/python/src/iceberg/avro/decoder.py +++ b/python/src/iceberg/avro/decoder.py @@ -17,6 +17,7 @@ import decimal import struct from datetime import date, datetime, time +from io import SEEK_CUR from iceberg.io.base import InputStream from iceberg.utils.datetime import ( @@ -56,6 +57,9 @@ def read(self, n: int) -> bytes: raise ValueError(f"Read {len(read_bytes)} bytes, expected {n} bytes") return read_bytes + def skip(self, n: int) -> None: + self._input_stream.seek(n, SEEK_CUR) + def read_boolean(self) -> bool: """ a boolean is written as a single byte @@ -64,11 +68,7 @@ def read_boolean(self) -> bool: return ord(self.read(1)) == 1 def read_int(self) -> int: - """int values are written using variable-length, zigzag coding.""" - return self.read_long() - - def read_long(self) -> int: - """long values are written using variable-length, zigzag coding.""" + """int/long values are written using variable-length, zigzag coding.""" b = ord(self.read(1)) n = b & 0x7F shift = 7 @@ -100,7 +100,7 @@ def read_decimal_from_bytes(self, precision: int, scale: int) -> decimal.Decimal Decimal bytes are decoded as signed short, int or long depending on the size of bytes. """ - size = self.read_long() + size = self.read_int() return self.read_decimal_from_fixed(precision, scale, size) def read_decimal_from_fixed(self, _: int, scale: int, size: int) -> decimal.Decimal: @@ -116,7 +116,7 @@ def read_bytes(self) -> bytes: """ Bytes are encoded as a long followed by that many bytes of data. """ - return self.read(self.read_long()) + return self.read(self.read_int()) def read_utf8(self) -> str: """ @@ -146,14 +146,14 @@ def read_time_micros(self) -> time: long is decoded as python time object which represents the number of microseconds after midnight, 00:00:00.000000. 
""" - return micros_to_time(self.read_long()) + return micros_to_time(self.read_int()) def read_timestamp_micros(self) -> datetime: """ long is decoded as python datetime object which represents the number of microseconds from the unix epoch, 1 January 1970. """ - return micros_to_timestamp(self.read_long()) + return micros_to_timestamp(self.read_int()) def read_timestamptz_micros(self): """ @@ -162,4 +162,24 @@ def read_timestamptz_micros(self): Adjusted to UTC """ - return micros_to_timestamptz(self.read_long()) + return micros_to_timestamptz(self.read_int()) + + def skip_boolean(self) -> None: + self.skip(1) + + def skip_int(self) -> None: + b = ord(self.read(1)) + while (b & 0x80) != 0: + b = ord(self.read(1)) + + def skip_float(self) -> None: + self.skip(4) + + def skip_double(self) -> None: + self.skip(8) + + def skip_bytes(self) -> None: + self.skip(self.read_int()) + + def skip_utf8(self) -> None: + self.skip_bytes() diff --git a/python/src/iceberg/avro/file.py b/python/src/iceberg/avro/file.py index 0eec227e9543..585c74c07c5e 100644 --- a/python/src/iceberg/avro/file.py +++ b/python/src/iceberg/avro/file.py @@ -27,6 +27,7 @@ from iceberg.avro.codecs import KNOWN_CODECS, Codec from iceberg.avro.decoder import BinaryDecoder from iceberg.avro.reader import AvroStruct, ConstructReader, StructReader +from iceberg.avro.resolver import resolve from iceberg.io.base import InputFile, InputStream from iceberg.io.memory import MemoryInputStream from iceberg.schema import Schema, visit @@ -107,6 +108,7 @@ def __next__(self) -> AvroStruct: class AvroFile: input_file: InputFile + read_schema: Schema | None input_stream: InputStream header: AvroFileHeader schema: Schema @@ -116,8 +118,9 @@ class AvroFile: decoder: BinaryDecoder block: Block | None = None - def __init__(self, input_file: InputFile) -> None: + def __init__(self, input_file: InputFile, read_schema: Schema | None = None) -> None: self.input_file = input_file + self.read_schema = read_schema def 
__enter__(self): """ @@ -132,7 +135,11 @@ def __enter__(self): self.header = self._read_header() self.schema = self.header.get_schema() self.file_length = len(self.input_file) - self.reader = visit(self.schema, ConstructReader()) + if not self.read_schema: + self.reader = visit(self.schema, ConstructReader()) + else: + self.reader = resolve(self.schema, self.read_schema) + return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -149,9 +156,9 @@ def _read_block(self) -> int: raise ValueError(f"Expected sync bytes {self.header.sync!r}, but got {sync_marker!r}") if self.is_EOF(): raise StopIteration - block_records = self.decoder.read_long() + block_records = self.decoder.read_int() - block_bytes_len = self.decoder.read_long() + block_bytes_len = self.decoder.read_int() block_bytes = self.decoder.read(block_bytes_len) if codec := self.header.compression_codec(): block_bytes = codec.decompress(block_bytes) diff --git a/python/src/iceberg/avro/reader.py b/python/src/iceberg/avro/reader.py index 012d611a1a51..c122c8a780e3 100644 --- a/python/src/iceberg/avro/reader.py +++ b/python/src/iceberg/avro/reader.py @@ -31,7 +31,7 @@ from datetime import date, datetime, time from decimal import Decimal from functools import singledispatch -from typing import Any +from typing import Any, Callable from uuid import UUID from iceberg.avro.decoder import BinaryDecoder @@ -60,6 +60,42 @@ from iceberg.utils.singleton import Singleton +def _skip_map_array(decoder: BinaryDecoder, skip_entry: Callable) -> None: + """Skips over an array or map + + Both the array and map are encoded similar, and we can re-use + the logic of skipping in an efficient way. + + From the Avro spec: + + Maps (and arrays) are encoded as a series of blocks. + Each block consists of a long count value, followed by that many key/value pairs in the case of a map, + and followed by that many array items in the case of an array. A block with count zero indicates the + end of the map. 
Each item is encoded per the map's value schema. + + If a block's count is negative, its absolute value is used, and the count is followed immediately by a + long block size indicating the number of bytes in the block. This block size permits fast skipping + through data, e.g., when projecting a record to a subset of its fields. + + Args: + decoder: + The decoder that reads the types from the underlying data + skip_entry: + Function to skip over the underlying data, element in case of an array, and the + key/value in the case of a map + """ + block_count = decoder.read_int() + while block_count != 0: + if block_count < 0: + # The length in bytes is encoded, so we can skip over it right away + block_size = decoder.read_int() + decoder.skip(block_size) + else: + for _ in range(block_count): + skip_entry() + block_count = decoder.read_int() + + @dataclass(frozen=True) class AvroStruct(StructProtocol): _data: list[Any | StructProtocol] = dataclassfield() @@ -76,66 +112,100 @@ class Reader(Singleton): def read(self, decoder: BinaryDecoder) -> Any: ... + @abstractmethod + def skip(self, decoder: BinaryDecoder) -> None: + ...
+ class NoneReader(Reader): def read(self, _: BinaryDecoder) -> None: return None + def skip(self, decoder: BinaryDecoder) -> None: + return None + class BooleanReader(Reader): def read(self, decoder: BinaryDecoder) -> bool: return decoder.read_boolean() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_boolean() + class IntegerReader(Reader): + """Longs and ints are encoded the same way, and there is no long in Python""" + def read(self, decoder: BinaryDecoder) -> int: return decoder.read_int() - -class LongReader(Reader): - def read(self, decoder: BinaryDecoder) -> int: - return decoder.read_long() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_int() class FloatReader(Reader): def read(self, decoder: BinaryDecoder) -> float: return decoder.read_float() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_float() + class DoubleReader(Reader): def read(self, decoder: BinaryDecoder) -> float: return decoder.read_double() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_double() + class DateReader(Reader): def read(self, decoder: BinaryDecoder) -> date: return decoder.read_date_from_int() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_int() + class TimeReader(Reader): def read(self, decoder: BinaryDecoder) -> time: return decoder.read_time_micros() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_int() + class TimestampReader(Reader): def read(self, decoder: BinaryDecoder) -> datetime: return decoder.read_timestamp_micros() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_int() + class TimestamptzReader(Reader): def read(self, decoder: BinaryDecoder) -> datetime: return decoder.read_timestamptz_micros() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_int() + class StringReader(Reader): def read(self, decoder: BinaryDecoder) -> str: return decoder.read_utf8() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_utf8() + class 
UUIDReader(Reader): def read(self, decoder: BinaryDecoder) -> UUID: return UUID(decoder.read_utf8()) + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_utf8() + @dataclass(frozen=True) class FixedReader(Reader): @@ -144,11 +214,17 @@ class FixedReader(Reader): def read(self, decoder: BinaryDecoder) -> bytes: return decoder.read(self.length) + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip(self.length) + class BinaryReader(Reader): def read(self, decoder: BinaryDecoder) -> bytes: return decoder.read_bytes() + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_bytes() + @dataclass(frozen=True) class DecimalReader(Reader): @@ -158,6 +234,9 @@ class DecimalReader(Reader): def read(self, decoder: BinaryDecoder) -> Decimal: return decoder.read_decimal_from_bytes(self.precision, self.scale) + def skip(self, decoder: BinaryDecoder) -> None: + decoder.skip_bytes() + @dataclass(frozen=True) class OptionReader(Reader): @@ -177,13 +256,28 @@ def read(self, decoder: BinaryDecoder) -> Any | None: return self.option.read(decoder) return None + def skip(self, decoder: BinaryDecoder) -> None: + if decoder.read_int() > 0: + return self.option.skip(decoder) + @dataclass(frozen=True) class StructReader(Reader): - fields: tuple[Reader, ...] = dataclassfield() + fields: tuple[tuple[int | None, Reader], ...] 
= dataclassfield() def read(self, decoder: BinaryDecoder) -> AvroStruct: - return AvroStruct([field.read(decoder) for field in self.fields]) + result: list[Any | StructProtocol] = [None] * len(self.fields) + for (pos, field) in self.fields: + if pos is not None: + result[pos] = field.read(decoder) + else: + field.skip(decoder) + + return AvroStruct(result) + + def skip(self, decoder: BinaryDecoder) -> None: + for _, field in self.fields: + field.skip(decoder) @dataclass(frozen=True) @@ -192,17 +286,19 @@ class ListReader(Reader): def read(self, decoder: BinaryDecoder) -> list: read_items = [] - block_count = decoder.read_long() + block_count = decoder.read_int() while block_count != 0: if block_count < 0: block_count = -block_count - # We ignore the block size for now - _ = decoder.read_long() + _ = decoder.read_int() for _ in range(block_count): read_items.append(self.element.read(decoder)) - block_count = decoder.read_long() + block_count = decoder.read_int() return read_items + def skip(self, decoder: BinaryDecoder) -> None: + _skip_map_array(decoder, lambda: self.element.skip(decoder)) + @dataclass(frozen=True) class MapReader(Reader): @@ -211,26 +307,33 @@ class MapReader(Reader): def read(self, decoder: BinaryDecoder) -> dict: read_items = {} - block_count = decoder.read_long() + block_count = decoder.read_int() while block_count != 0: if block_count < 0: block_count = -block_count # We ignore the block size for now - _ = decoder.read_long() + _ = decoder.read_int() for _ in range(block_count): key = self.key.read(decoder) read_items[key] = self.value.read(decoder) - block_count = decoder.read_long() + block_count = decoder.read_int() return read_items + def skip(self, decoder: BinaryDecoder) -> None: + def skip(): + self.key.skip(decoder) + self.value.skip(decoder) + + _skip_map_array(decoder, skip) + class ConstructReader(SchemaVisitor[Reader]): def schema(self, schema: Schema, struct_result: Reader) -> Reader: return struct_result def struct(self, struct: 
StructType, field_results: list[Reader]) -> Reader: - return StructReader(tuple(field_results)) + return StructReader(tuple(enumerate(field_results))) def field(self, field: NestedField, field_result: Reader) -> Reader: return field_result if field.required else OptionReader(field_result) @@ -274,7 +377,9 @@ def _(_: IntegerType) -> Reader: @primitive_reader.register(LongType) def _(_: LongType) -> Reader: - return LongReader() + # Ints and longs are encoded the same way in Python and + # also binary compatible in Avro + return IntegerReader() @primitive_reader.register(FloatType) diff --git a/python/src/iceberg/avro/resolver.py b/python/src/iceberg/avro/resolver.py new file mode 100644 index 000000000000..11017742abfa --- /dev/null +++ b/python/src/iceberg/avro/resolver.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from functools import singledispatch +from typing import ( + List, + Optional, + Tuple, + Union, +) + +from iceberg.avro.reader import ( + ConstructReader, + ListReader, + MapReader, + NoneReader, + OptionReader, + Reader, + StructReader, + primitive_reader, +) +from iceberg.schema import Schema, visit +from iceberg.types import ( + BinaryType, + DecimalType, + DoubleType, + FloatType, + IcebergType, + IntegerType, + ListType, + LongType, + MapType, + PrimitiveType, + StringType, + StructType, +) + + +class ResolveException(Exception): + pass + + +@singledispatch +def resolve(file_schema: Union[Schema, IcebergType], read_schema: Union[Schema, IcebergType]) -> Reader: + """This resolves the file and read schema + + The function traverses the schema in post-order fashion + + Args: + file_schema (Schema | IcebergType): The schema of the Avro file + read_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema + + Raises: + NotImplementedError: If attempting to resolve an unrecognized object type + """ + raise NotImplementedError(f"Cannot resolve non-type: {file_schema}") + + +@resolve.register(Schema) +def _(file_schema: Schema, read_schema: Schema) -> Reader: + """Visit a Schema and starts resolving it by converting it to a struct""" + return resolve(file_schema.as_struct(), read_schema.as_struct()) + + +@resolve.register(StructType) +def _(file_struct: StructType, read_struct: IcebergType) -> Reader: + """Iterates over the file schema, and checks if the field is in the read schema""" + + if not isinstance(read_struct, StructType): + raise ResolveException(f"File/read schema are not aligned for {file_struct}, got {read_struct}") + + results: List[Tuple[Optional[int], Reader]] = [] + read_fields = {field.field_id: (pos, field) for pos, field in enumerate(read_struct.fields)} + + for file_field in file_struct.fields: + if file_field.field_id in read_fields: + read_pos, read_field = read_fields[file_field.field_id] 
+ result_reader = resolve(file_field.field_type, read_field.field_type) + else: + read_pos = None + result_reader = visit(file_field.field_type, ConstructReader()) + result_reader = result_reader if file_field.required else OptionReader(result_reader) + results.append((read_pos, result_reader)) + + file_fields = {field.field_id: field for field in file_struct.fields} + for pos, read_field in enumerate(read_struct.fields): + if read_field.field_id not in file_fields: + if read_field.required: + raise ResolveException(f"{read_field} is non-optional, and not part of the file schema") + # Just set the new field to None + results.append((pos, NoneReader())) + + return StructReader(tuple(results)) + + +@resolve.register(ListType) +def _(file_list: ListType, read_list: IcebergType) -> Reader: + if not isinstance(read_list, ListType): + raise ResolveException(f"File/read schema are not aligned for {file_list}, got {read_list}") + element_reader = resolve(file_list.element_type, read_list.element_type) + return ListReader(element_reader) + + +@resolve.register(MapType) +def _(file_map: MapType, read_map: IcebergType) -> Reader: + if not isinstance(read_map, MapType): + raise ResolveException(f"File/read schema are not aligned for {file_map}, got {read_map}") + key_reader = resolve(file_map.key_type, read_map.key_type) + value_reader = resolve(file_map.value_type, read_map.value_type) + + return MapReader(key_reader, value_reader) + + +@resolve.register(PrimitiveType) +def _(file_type: PrimitiveType, read_type: IcebergType) -> Reader: + """Converting the primitive type into an actual reader that will decode the physical data""" + if not isinstance(read_type, PrimitiveType): + raise ResolveException(f"Cannot promote {file_type} to {read_type}") + + # In the case of a promotion, we want to check if it is valid + if file_type != read_type: + return promote(file_type, read_type) + return primitive_reader(read_type) + + +@singledispatch +def promote(file_type: IcebergType, 
read_type: IcebergType) -> Reader: + """Promotes reading a file type to a read type + + Args: + file_type (IcebergType): The type of the Avro file + read_type (IcebergType): The requested read type + + Raises: + ResolveException: If attempting to resolve an unrecognized object type + """ + raise ResolveException(f"Cannot promote {file_type} to {read_type}") + + +@promote.register(IntegerType) +def _(file_type: IntegerType, read_type: IcebergType) -> Reader: + if isinstance(read_type, LongType): + # Ints/Longs are binary compatible in Avro, so this is okay + return primitive_reader(read_type) + else: + raise ResolveException(f"Cannot promote an int to {read_type}") + + +@promote.register(FloatType) +def _(file_type: FloatType, read_type: IcebergType) -> Reader: + if isinstance(read_type, DoubleType): + # We should just read the float, and return it, since it both returns a float + return primitive_reader(file_type) + else: + raise ResolveException(f"Cannot promote an float to {read_type}") + + +@promote.register(StringType) +def _(file_type: StringType, read_type: IcebergType) -> Reader: + if isinstance(read_type, BinaryType): + return primitive_reader(read_type) + else: + raise ResolveException(f"Cannot promote an string to {read_type}") + + +@promote.register(BinaryType) +def _(file_type: BinaryType, read_type: IcebergType) -> Reader: + if isinstance(read_type, StringType): + return primitive_reader(read_type) + else: + raise ResolveException(f"Cannot promote an binary to {read_type}") + + +@promote.register(DecimalType) +def _(file_type: DecimalType, read_type: IcebergType) -> Reader: + if isinstance(read_type, DecimalType): + if file_type.precision <= read_type.precision and file_type.scale == read_type.scale: + return primitive_reader(read_type) + else: + raise ResolveException(f"Cannot reduce precision from {file_type} to {read_type}") + else: + raise ResolveException(f"Cannot promote an decimal to {read_type}") diff --git a/python/src/iceberg/exceptions.py
b/python/src/iceberg/exceptions.py index b12e836e46ce..f9ac3333b109 100644 --- a/python/src/iceberg/exceptions.py +++ b/python/src/iceberg/exceptions.py @@ -33,4 +33,4 @@ class AlreadyExistsError(Exception): class ValidationError(Exception): - ... + """Raises when there is an issue with the schema""" diff --git a/python/tests/avro/test_decoder.py b/python/tests/avro/test_decoder.py index 295715115183..b988cb6c4970 100644 --- a/python/tests/avro/test_decoder.py +++ b/python/tests/avro/test_decoder.py @@ -21,8 +21,10 @@ import pytest from iceberg.avro.decoder import BinaryDecoder +from iceberg.avro.resolver import promote from iceberg.io.base import InputStream from iceberg.io.memory import MemoryInputStream +from iceberg.types import DoubleType, FloatType def test_read_decimal_from_fixed(): @@ -33,10 +35,38 @@ def test_read_decimal_from_fixed(): assert actual == expected -def test_read_long(): +def test_read_boolean_true(): + mis = MemoryInputStream(b"\x01") + decoder = BinaryDecoder(mis) + assert decoder.read_boolean() is True + + +def test_read_boolean_false(): + mis = MemoryInputStream(b"\x00") + decoder = BinaryDecoder(mis) + assert decoder.read_boolean() is False + + +def test_skip_boolean(): + mis = MemoryInputStream(b"\x00") + decoder = BinaryDecoder(mis) + assert mis.tell() == 0 + decoder.skip_boolean() + assert mis.tell() == 1 + + +def test_read_int(): mis = MemoryInputStream(b"\x18") decoder = BinaryDecoder(mis) - assert decoder.read_long() == 12 + assert decoder.read_int() == 12 + + +def test_skip_int(): + mis = MemoryInputStream(b"\x18") + decoder = BinaryDecoder(mis) + assert mis.tell() == 0 + decoder.skip_int() + assert mis.tell() == 1 def test_read_decimal(): @@ -104,12 +134,28 @@ def test_read_float(): assert decoder.read_float() == 19.25 +def test_skip_float(): + mis = MemoryInputStream(b"\x00\x00\x9A\x41") + decoder = BinaryDecoder(mis) + assert mis.tell() == 0 + decoder.skip_float() + assert mis.tell() == 4 + + def test_read_double(): mis = 
MemoryInputStream(b"\x00\x00\x00\x00\x00\x40\x33\x40") decoder = BinaryDecoder(mis) assert decoder.read_double() == 19.25 +def test_skip_double(): + mis = MemoryInputStream(b"\x00\x00\x00\x00\x00\x40\x33\x40") + decoder = BinaryDecoder(mis) + assert mis.tell() == 0 + decoder.skip_double() + assert mis.tell() == 8 + + def test_read_date(): mis = MemoryInputStream(b"\xBC\x7D") decoder = BinaryDecoder(mis) @@ -138,3 +184,32 @@ def test_read_timestamptz_micros(): mis = MemoryInputStream(b"\xBC\x7D") decoder = BinaryDecoder(mis) assert decoder.read_timestamptz_micros() == datetime(1970, 1, 1, 0, 0, 0, 8030, tzinfo=timezone.utc) + + +def test_read_bytes(): + mis = MemoryInputStream(b"\x08\x01\x02\x03\x04") + decoder = BinaryDecoder(mis) + actual = decoder.read_bytes() + assert actual == b"\x01\x02\x03\x04" + + +def test_read_utf8(): + mis = MemoryInputStream(b"\x04\x76\x6F") + decoder = BinaryDecoder(mis) + assert decoder.read_utf8() == "vo" + + +def test_skip_utf8(): + mis = MemoryInputStream(b"\x04\x76\x6F") + decoder = BinaryDecoder(mis) + assert mis.tell() == 0 + decoder.skip_utf8() + assert mis.tell() == 3 + + +def test_read_int_as_float(): + mis = MemoryInputStream(b"\x00\x00\x9A\x41") + decoder = BinaryDecoder(mis) + reader = promote(FloatType(), DoubleType()) + + assert reader.read(decoder) == 19.25 diff --git a/python/tests/avro/test_reader.py b/python/tests/avro/test_reader.py index 6ea8377e1104..96afa390f8a1 100644 --- a/python/tests/avro/test_reader.py +++ b/python/tests/avro/test_reader.py @@ -30,7 +30,6 @@ FixedReader, FloatReader, IntegerReader, - LongReader, StringReader, TimeReader, TimestampReader, @@ -411,7 +410,7 @@ def test_integer_reader(): def test_long_reader(): - assert primitive_reader(LongType()) == LongReader() + assert primitive_reader(LongType()) == IntegerReader() def test_float_reader(): diff --git a/python/tests/avro/test_resolver.py b/python/tests/avro/test_resolver.py new file mode 100644 index 000000000000..197dc95246ee --- /dev/null 
+++ b/python/tests/avro/test_resolver.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from iceberg.avro.reader import ( + DecimalReader, + DoubleReader, + FloatReader, + IntegerReader, + MapReader, + StringReader, + StructReader, +) +from iceberg.avro.resolver import ResolveException, promote, resolve +from iceberg.schema import Schema +from iceberg.types import ( + BinaryType, + DecimalType, + DoubleType, + FloatType, + IntegerType, + ListType, + LongType, + MapType, + NestedField, + StringType, + StructType, +) + + +def test_resolver(): + write_schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType()), + NestedField( + 3, + "location", + StructType( + NestedField(4, "lat", DoubleType()), + NestedField(5, "long", DoubleType()), + ), + ), + NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())), + schema_id=1, + ) + read_schema = Schema( + NestedField( + 3, + "location", + StructType( + NestedField(4, "lat", DoubleType()), + NestedField(5, "long", DoubleType()), + ), + ), + NestedField(1, "id", LongType()), + NestedField(6, "preferences", MapType(7, StringType(), 8, StringType())), + schema_id=1, + ) + read_tree = resolve(write_schema, 
read_schema) + + assert read_tree == StructReader( + ( + (1, IntegerReader()), + (None, StringReader()), + ( + 0, + StructReader( + ( + (0, DoubleReader()), + (1, DoubleReader()), + ) + ), + ), + (2, MapReader(StringReader(), StringReader())), + ) + ) + + +def test_resolver_new_required_field(): + write_schema = Schema( + NestedField(1, "id", LongType()), + schema_id=1, + ) + read_schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "data", StringType(), required=True), + schema_id=1, + ) + + with pytest.raises(ResolveException) as exc_info: + resolve(write_schema, read_schema) + + assert "2: data: required string is non-optional, and not part of the file schema" in str(exc_info.value) + + +def test_resolver_invalid_evolution(): + write_schema = Schema( + NestedField(1, "id", LongType()), + schema_id=1, + ) + read_schema = Schema( + NestedField(1, "id", DoubleType()), + schema_id=1, + ) + + with pytest.raises(ResolveException) as exc_info: + resolve(write_schema, read_schema) + + assert "Cannot promote long to double" in str(exc_info.value) + + +def test_resolver_promotion_string_to_binary(): + write_schema = Schema( + NestedField(1, "id", StringType()), + schema_id=1, + ) + read_schema = Schema( + NestedField(1, "id", BinaryType()), + schema_id=1, + ) + resolve(write_schema, read_schema) + + +def test_resolver_promotion_binary_to_string(): + write_schema = Schema( + NestedField(1, "id", BinaryType()), + schema_id=1, + ) + read_schema = Schema( + NestedField(1, "id", StringType()), + schema_id=1, + ) + resolve(write_schema, read_schema) + + +def test_resolver_change_type(): + write_schema = Schema( + NestedField(1, "properties", ListType(2, StringType())), + schema_id=1, + ) + read_schema = Schema( + NestedField(1, "properties", MapType(2, StringType(), 3, StringType())), + schema_id=1, + ) + + with pytest.raises(ResolveException) as exc_info: + resolve(write_schema, read_schema) + + assert "File/read schema are not aligned for list, got map" in 
str(exc_info.value) + + +def test_promote_int_to_long(): + assert promote(IntegerType(), LongType()) == IntegerReader() + + +def test_promote_float_to_double(): + # We should still read floats, because it is encoded in 4 bytes + assert promote(FloatType(), DoubleType()) == FloatReader() + + +def test_promote_decimal_to_decimal(): + # DecimalType(P, S) to DecimalType(P2, S) where P2 > P + assert promote(DecimalType(19, 25), DecimalType(22, 25)) == DecimalReader(22, 25) + + +def test_promote_decimal_to_decimal_reduce_precision(): + # DecimalType(P, S) to DecimalType(P2, S) where P2 < P + with pytest.raises(ResolveException) as exc_info: + _ = promote(DecimalType(19, 25), DecimalType(10, 25)) == DecimalReader(22, 25) + + assert "Cannot reduce precision from decimal(19, 25) to decimal(10, 25)" in str(exc_info.value)