Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/src/iceberg/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,12 @@ def bind(self, schema: Schema, case_sensitive: bool) -> BoundReference:
if not field:
raise ValueError(f"Cannot find field '{self.name}' in schema: {schema}")

return BoundReference(field=field, accessor=schema.accessor_for_field(field.field_id))
accessor = schema.accessor_for_field(field.field_id)

if not accessor:
raise ValueError(f"Cannot find accessor for field '{self.name}' in schema: {schema}")

return BoundReference(field=field, accessor=accessor)


class BooleanExpressionVisitor(Generic[T], ABC):
Expand Down
45 changes: 19 additions & 26 deletions python/src/iceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import singledispatch
from functools import cached_property, singledispatch
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -54,10 +54,6 @@ def __init__(self, *columns: NestedField, schema_id: int, identifier_field_ids:
self._schema_id = schema_id
self._identifier_field_ids = identifier_field_ids or []
self._name_to_id: dict[str, int] = index_by_name(self)
self._name_to_id_lower: dict[str, int] = {} # Should be accessed through self._lazy_name_to_id_lower()
self._id_to_field: dict[int, NestedField] = {} # Should be accessed through self._lazy_id_to_field()
self._id_to_name: dict[int, str] = {} # Should be accessed through self._lazy_id_to_name()
self._id_to_accessor: dict[int, Accessor] = {} # Should be accessed through self._lazy_id_to_accessor()

def __str__(self):
return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}"
Expand Down Expand Up @@ -96,47 +92,43 @@ def schema_id(self) -> int:
def identifier_field_ids(self) -> list[int]:
return self._identifier_field_ids

@cached_property
def _lazy_id_to_field(self) -> dict[int, NestedField]:
"""Returns an index of field ID to NestedField instance

This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
if not self._id_to_field:
self._id_to_field = index_by_id(self)
return self._id_to_field
return index_by_id(self)

@cached_property
def _lazy_name_to_id_lower(self) -> dict[str, int]:
"""Returns an index of lower-case field names to field IDs

This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
if not self._name_to_id_lower:
self._name_to_id_lower = {name.lower(): field_id for name, field_id in self._name_to_id.items()}
return self._name_to_id_lower
return {name.lower(): field_id for name, field_id in self._name_to_id.items()}

@cached_property
def _lazy_id_to_name(self) -> dict[int, str]:
"""Returns an index of field ID to full name

This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
if not self._id_to_name:
self._id_to_name = index_name_by_id(self)
return self._id_to_name
return index_name_by_id(self)

@cached_property
def _lazy_id_to_accessor(self) -> dict[int, Accessor]:
"""Returns an index of field ID to accessor

This is calculated once when called for the first time. Subsequent calls to this method will use a cached index.
"""
if not self._id_to_accessor:
self._id_to_accessor = build_position_accessors(self)
return self._id_to_accessor
return build_position_accessors(self)

def as_struct(self) -> StructType:
"""Returns the underlying struct"""
return self._struct

def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField:
def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField | None:
"""Find a field using a field name or field ID

Args:
Expand All @@ -147,13 +139,12 @@ def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> Nest
NestedField: The matched NestedField
"""
if isinstance(name_or_id, int):
field = self._lazy_id_to_field().get(name_or_id)
return field # type: ignore
return self._lazy_id_to_field.get(name_or_id)
if case_sensitive:
field_id = self._name_to_id.get(name_or_id)
else:
field_id = self._lazy_name_to_id_lower().get(name_or_id.lower())
return self._lazy_id_to_field().get(field_id) # type: ignore
field_id = self._lazy_name_to_id_lower.get(name_or_id.lower())
return self._lazy_id_to_field.get(field_id) # type: ignore

def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> IcebergType:
"""Find a field type using a field name or field ID
Expand All @@ -166,9 +157,11 @@ def find_type(self, name_or_id: str | int, case_sensitive: bool = True) -> Icebe
NestedField: The type of the matched NestedField
"""
field = self.find_field(name_or_id=name_or_id, case_sensitive=case_sensitive)
if not field:
raise ValueError(f"Could not find field with name or id {name_or_id}, case_sensitive={case_sensitive}")
return field.field_type

def find_column_name(self, column_id: int) -> str:
def find_column_name(self, column_id: int) -> str | None:
"""Find a column name given a column ID

Args:
Expand All @@ -177,9 +170,9 @@ def find_column_name(self, column_id: int) -> str:
Returns:
str: The column name (or None if the column ID cannot be found)
"""
return self._lazy_id_to_name().get(column_id) # type: ignore
return self._lazy_id_to_name.get(column_id)

def accessor_for_field(self, field_id: int) -> Accessor:
def accessor_for_field(self, field_id: int) -> Accessor | None:
"""Find a schema position accessor given a field ID

Args:
Expand All @@ -188,7 +181,7 @@ def accessor_for_field(self, field_id: int) -> Accessor:
Returns:
Accessor: An accessor for the given field ID
"""
return self._lazy_id_to_accessor().get(field_id) # type: ignore
return self._lazy_id_to_accessor.get(field_id)

def select(self, names: list[str], case_sensitive: bool = True) -> Schema:
"""Return a new schema instance pruned to a subset of columns
Expand Down