Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Add support for initial default #7699

Merged
merged 3 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion python/pyiceberg/avro/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import json
from dataclasses import dataclass
from enum import Enum
from types import TracebackType
from typing import (
Callable,
Expand Down Expand Up @@ -121,6 +122,7 @@ class AvroFile(Generic[D]):
input_file: InputFile
read_schema: Optional[Schema]
read_types: Dict[int, Callable[..., StructProtocol]]
read_enums: Dict[int, Callable[..., Enum]]
input_stream: InputStream
header: AvroFileHeader
schema: Schema
Expand All @@ -134,10 +136,12 @@ def __init__(
input_file: InputFile,
read_schema: Optional[Schema] = None,
read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
) -> None:
self.input_file = input_file
self.read_schema = read_schema
self.read_types = read_types
self.read_enums = read_enums

def __enter__(self) -> AvroFile[D]:
"""
Expand All @@ -154,7 +158,7 @@ def __enter__(self) -> AvroFile[D]:
if not self.read_schema:
self.read_schema = self.schema

self.reader = resolve(self.schema, self.read_schema, self.read_types)
self.reader = resolve(self.schema, self.read_schema, self.read_types, self.read_enums)

return self

Expand Down
13 changes: 13 additions & 0 deletions python/pyiceberg/avro/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ def skip(self, decoder: BinaryDecoder) -> None:
return None


class DefaultReader(Reader):
    """Reader that yields a fixed value instead of decoding anything.

    The value handed to the constructor is returned unchanged from every
    ``read`` call; the decoder is never touched by either method.
    """

    default_value: Any

    def __init__(self, default_value: Any) -> None:
        """Remember the value that every ``read`` call should return."""
        self.default_value = default_value

    def read(self, _: BinaryDecoder) -> Any:
        """Return the stored default; the decoder argument is ignored."""
        return self.default_value

    def skip(self, decoder: BinaryDecoder) -> None:
        """Do nothing: ``read`` consumes no input, so there is nothing to skip."""
        return None


class BooleanReader(Reader):
def read(self, decoder: BinaryDecoder) -> bool:
return decoder.read_boolean()
Expand Down
59 changes: 50 additions & 9 deletions python/pyiceberg/avro/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed,unused-argument
from enum import Enum
from typing import (
Callable,
Dict,
Expand All @@ -24,11 +25,13 @@
Union,
)

from pyiceberg.avro.decoder import BinaryDecoder
from pyiceberg.avro.reader import (
BinaryReader,
BooleanReader,
DateReader,
DecimalReader,
DefaultReader,
DoubleReader,
FixedReader,
FloatReader,
Expand Down Expand Up @@ -77,6 +80,8 @@
UUIDType,
)

STRUCT_ROOT = -1


def construct_reader(
file_schema: Union[Schema, IcebergType], read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT
Expand All @@ -96,26 +101,53 @@ def resolve(
file_schema: Union[Schema, IcebergType],
read_schema: Union[Schema, IcebergType],
read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
) -> Reader:
"""Resolves the file and read schema to produce a reader

Args:
file_schema (Schema | IcebergType): The schema of the Avro file
read_schema (Schema | IcebergType): The requested read schema which is equal, subset or superset of the file schema
read_types (Dict[int, Callable[[Schema], StructProtocol]]): A dict of types to use for struct data
read_types (Dict[int, Callable[..., StructProtocol]]): A dict of types to use for struct data
read_enums (Dict[int, Callable[..., Enum]]): A dict of fields that have to be converted to an enum

Raises:
NotImplementedError: If attempting to resolve an unrecognized object type
"""
return visit_with_partner(file_schema, read_schema, SchemaResolver(read_types), SchemaPartnerAccessor()) # type: ignore
return visit_with_partner(
file_schema, read_schema, SchemaResolver(read_types, read_enums), SchemaPartnerAccessor()
) # type: ignore


class EnumReader(Reader):
    """An Enum reader to wrap primitive values into an Enum.

    The primitive value is decoded by the wrapped ``reader`` and then passed
    to the ``enum`` callable to build the enum member.
    """

    enum: Callable[..., Enum]
    reader: Reader

    def __init__(self, enum: Callable[..., Enum], reader: Reader) -> None:
        """Store the enum constructor and the reader producing its raw value.

        Args:
            enum: Callable that turns the decoded primitive into an Enum.
            reader: Reader that decodes the underlying primitive value.
        """
        self.enum = enum
        self.reader = reader

    def read(self, decoder: BinaryDecoder) -> Enum:
        """Decode the primitive with the wrapped reader and convert it to an Enum."""
        return self.enum(self.reader.read(decoder))

    def skip(self, decoder: BinaryDecoder) -> None:
        """Skip the underlying primitive value.

        The wrapped value is physically present in the input, so skipping has
        to be delegated to the wrapped reader; a no-op here would leave the
        decoder positioned in front of the value that should have been skipped.
        """
        self.reader.skip(decoder)


class SchemaResolver(PrimitiveWithPartnerVisitor[IcebergType, Reader]):
read_types: Dict[int, Callable[..., StructProtocol]]
read_enums: Dict[int, Callable[..., Enum]]
context: List[int]

def __init__(self, read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT) -> None:
def __init__(
self,
read_types: Dict[int, Callable[..., StructProtocol]] = EMPTY_DICT,
read_enums: Dict[int, Callable[..., Enum]] = EMPTY_DICT,
) -> None:
self.read_types = read_types
self.read_enums = read_enums
self.context = []

def schema(self, schema: Schema, expected_schema: Optional[IcebergType], result: Reader) -> Reader:
Expand All @@ -128,8 +160,7 @@ def after_field(self, field: NestedField, field_partner: Optional[NestedField])
self.context.pop()

def struct(self, struct: StructType, expected_struct: Optional[IcebergType], field_readers: List[Reader]) -> Reader:
# -1 indicates the struct root
read_struct_id = self.context[-1] if len(self.context) > 0 else -1
read_struct_id = self.context[STRUCT_ROOT] if len(self.context) > 0 else STRUCT_ROOT
struct_callable = self.read_types.get(read_struct_id, Record)

if not expected_struct:
Expand All @@ -142,16 +173,26 @@ def struct(self, struct: StructType, expected_struct: Optional[IcebergType], fie

# first, add readers for the file fields that must be in order
results: List[Tuple[Optional[int], Reader]] = [
(expected_positions.get(field.field_id), result_reader) for field, result_reader in zip(struct.fields, field_readers)
(
expected_positions.get(field.field_id),
# Check if we need to convert it to an Enum
result_reader if not (enum_type := self.read_enums.get(field.field_id)) else EnumReader(enum_type, result_reader),
)
for field, result_reader in zip(struct.fields, field_readers)
]

file_fields = {field.field_id: field for field in struct.fields}
for pos, read_field in enumerate(expected_struct.fields):
if read_field.field_id not in file_fields:
if read_field.required:
if isinstance(read_field, NestedField) and read_field.initial_default is not None:
# The field is not in the file, but there is a default value
# and that one can be required
results.append((pos, DefaultReader(read_field.initial_default)))
elif read_field.required:
raise ResolveError(f"{read_field} is non-optional, and not part of the file schema")
# Just set the new field to None
results.append((pos, NoneReader()))
rdblue marked this conversation as resolved.
Show resolved Hide resolved
else:
# Just set the new field to None
results.append((pos, NoneReader()))

return StructReader(tuple(results), struct_callable, expected_struct)

Expand Down
10 changes: 6 additions & 4 deletions python/pyiceberg/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,9 @@ def __init__(self, *data: Any, **named_data: Any) -> None:
NestedField(500, "manifest_path", StringType(), required=True, doc="Location URI with FS scheme"),
NestedField(501, "manifest_length", LongType(), required=True),
NestedField(502, "partition_spec_id", IntegerType(), required=True),
NestedField(517, "content", IntegerType(), required=False),
NestedField(515, "sequence_number", LongType(), required=False),
NestedField(516, "min_sequence_number", LongType(), required=False),
NestedField(517, "content", IntegerType(), required=False, initial_default=ManifestContent.DATA),
NestedField(515, "sequence_number", LongType(), required=False, initial_default=0),
NestedField(516, "min_sequence_number", LongType(), required=False, initial_default=0),
NestedField(503, "added_snapshot_id", LongType(), required=False),
NestedField(504, "added_files_count", IntegerType(), required=False),
NestedField(505, "existing_files_count", IntegerType(), required=False),
Expand Down Expand Up @@ -283,5 +283,7 @@ def files(input_file: InputFile) -> Iterator[DataFile]:


def read_manifest_list(input_file: InputFile) -> Iterator[ManifestFile]:
    """Yield the entries of a manifest list file.

    The file is opened as an Avro file and read with ``MANIFEST_FILE_SCHEMA``
    as the read schema.

    Args:
        input_file: The manifest list file to read.

    Yields:
        Every record produced by the Avro reader, in file order.
    """
    with AvroFile[ManifestFile](
        input_file,
        MANIFEST_FILE_SCHEMA,
        # -1 is the struct root: top-level records become ManifestFile,
        # the struct under field 508 becomes PartitionFieldSummary.
        read_types={-1: ManifestFile, 508: PartitionFieldSummary},
        # Field 517 ("content") is converted into a ManifestContent enum.
        read_enums={517: ManifestContent},
    ) as reader:
        yield from reader
3 changes: 3 additions & 0 deletions python/pyiceberg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class NestedField(IcebergType):
field_type: IcebergType = Field(alias="type")
required: bool = Field(default=True)
doc: Optional[str] = Field(default=None, repr=False)
initial_default: Any = Field(alias="initial-default", repr=False)

def __init__(
self,
Expand All @@ -227,6 +228,7 @@ def __init__(
field_type: Optional[IcebergType] = None,
required: bool = True,
doc: Optional[str] = None,
initial_default: Optional[Any] = None,
**data: Any,
):
# We need an init when we want to use positional arguments, but
Expand All @@ -236,6 +238,7 @@ def __init__(
data["field_type"] = data["type"] if "type" in data else field_type
data["required"] = required
data["doc"] = doc
data["initial_default"] = initial_default
super().__init__(**data)

def __str__(self) -> str:
Expand Down
21 changes: 21 additions & 0 deletions python/tests/avro/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from pyiceberg.avro.file import AvroFile
from pyiceberg.avro.reader import (
DecimalReader,
DefaultReader,
DoubleReader,
FloatReader,
IntegerReader,
Expand Down Expand Up @@ -280,3 +281,23 @@ class Ints(Record):
records = list(reader)

assert repr(records) == "[Ints[c=3, d=None]]"


def test_resolver_initial_value() -> None:
    """A read field missing from the file resolves to its initial default."""
    # The file only contains field 1 ...
    file_schema = Schema(NestedField(1, "name", StringType()), schema_id=1)
    # ... while the reader asks solely for field 2, which carries a default.
    requested_schema = Schema(
        NestedField(2, "something", StringType(), required=False, initial_default="vo"),
        schema_id=2,
    )

    actual_reader = resolve(file_schema, requested_schema)

    expected_field_readers = (
        (None, StringReader()),  # The one we skip
        (0, DefaultReader("vo")),
    )
    expected_reader = StructReader(expected_field_readers, Record, requested_schema.as_struct())

    assert actual_reader == expected_reader
Loading