# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import re
import struct
from decimal import Decimal
from enum import Enum
from typing import Callable, Optional
from uuid import UUID

import mmh3  # type: ignore

from iceberg.types import (
    BinaryType,
    DateType,
    DecimalType,
    FixedType,
    IcebergType,
    IntegerType,
    LongType,
    StringType,
    TimestampType,
    TimestamptzType,
    TimeType,
    Truncatable,
    UUIDType,
)
from iceberg.utils import transform_util


class Transform:
    """Base class for concrete transforms.

    A transform maps source values to partition values and answers questions about
    type compatibility and ordering. This class is not used directly; create concrete
    transforms with the module-level factory functions (`bucket`, `truncate`, `identity`,
    `year`, `month`, `day`, `hour`, `always_null`, `from_string`).

    Args:
        transform_string (str): name of the transform type, returned by `str()`
        repr_string (str): string representation of a transform instance, returned by `repr()`
    """

    def __init__(self, transform_string: str, repr_string: str):
        self._transform_string = transform_string
        self._repr_string = repr_string

    def __repr__(self):
        return self._repr_string

    def __str__(self):
        return self._transform_string

    def __call__(self, value):
        # Calling a transform is shorthand for `apply`.
        return self.apply(value)

    def apply(self, value):
        """Transform a source value into a partition value. Subclasses must override."""
        raise NotImplementedError()

    def can_transform(self, target: IcebergType) -> bool:
        """Return whether this transform can be applied to values of type `target`."""
        return False

    def result_type(self, source: IcebergType) -> IcebergType:
        """Return the type produced by this transform; the source type by default."""
        return source

    def preserves_order(self) -> bool:
        """Return whether the transform keeps the relative order of source values."""
        return False

    def satisfies_order_of(self, other) -> bool:
        """Return whether ordering by this transform also satisfies the ordering of `other`."""
        return self == other

    def to_human_string(self, value) -> str:
        """Return a human-readable form of a transformed value; None renders as "null"."""
        if value is None:
            return "null"
        return str(value)

    def dedup_name(self) -> str:
        """Return the name used to detect redundant transforms; the transform string by default."""
        return self._transform_string


class BaseBucketTransform(Transform):
    """Base Transform class to transform a value into a bucket partition value

    Transforms are parameterized by a number of buckets. Bucket partition transforms use a 32-bit
    hash of the source value to produce a non-negative value by taking it modulo the bucket number.

    Args:
      source_type (Type): An Iceberg Type of IntegerType, LongType, DecimalType, DateType, TimeType,
      TimestampType, TimestamptzType, StringType, BinaryType, FixedType, UUIDType.
      num_buckets (int): The number of buckets.
    """

    def __init__(self, source_type: IcebergType, num_buckets: int):
        super().__init__(
            f"bucket[{num_buckets}]",
            f"transforms.bucket(source_type={repr(source_type)}, num_buckets={num_buckets})",
        )
        self._num_buckets = num_buckets

    @property
    def num_buckets(self) -> int:
        return self._num_buckets

    def hash(self, value) -> int:
        """Return the 32-bit hash of `value`. Type-specific subclasses must override."""
        raise NotImplementedError()

    def apply(self, value) -> Optional[int]:
        """Return the bucket for `value`, or None when `value` is None."""
        if value is None:
            return None

        # Mask with IntegerType.max before the modulo so a negative hash cannot
        # yield a negative bucket.
        return (self.hash(value) & IntegerType.max) % self._num_buckets

    def can_transform(self, target: IcebergType) -> bool:
        raise NotImplementedError()

    def result_type(self, source: IcebergType) -> IcebergType:
        return IntegerType()


class BucketIntegerTransform(BaseBucketTransform):
    """Transforms a value of IntegerType or DateType into a bucket partition value

    Example:
        >>> transform = BucketIntegerTransform(IntegerType(), 100)
        >>> transform.apply(34)
        79
    """

    def can_transform(self, target: IcebergType) -> bool:
        return type(target) in {IntegerType, DateType}

    def hash(self, value) -> int:
        # Hash the value as an 8-byte long. "<q" pins little-endian byte order so the
        # hash does not depend on the byte order of the host.
        return mmh3.hash(struct.pack("<q", value))


class BucketLongTransform(BaseBucketTransform):
    """Transforms a value of LongType, TimeType, TimestampType, or TimestamptzType
    into a bucket partition value

    Example:
        >>> transform = BucketLongTransform(TimeType(), 100)
        >>> transform.apply(81068000000)
        59
    """

    def can_transform(self, target: IcebergType) -> bool:
        return type(target) in {LongType, TimeType, TimestampType, TimestamptzType}

    def hash(self, value) -> int:
        # "<q" pins little-endian byte order so the hash is host-independent.
        return mmh3.hash(struct.pack("<q", value))


class BucketDoubleTransform(BaseBucketTransform):
    """Transforms a value of FloatType or DoubleType into a bucket partition value.

    Note that bucketing by Double is not allowed by the spec, but this has the hash implementation.
    `can_transform` is not overridden here, so calling it raises NotImplementedError.

    Example:
        >>> transform = BucketDoubleTransform(DoubleType(), 8)
        >>> transform.hash(1.0)
        -142385009
    """

    def hash(self, value) -> int:
        # "<d" pins little-endian byte order so the hash is host-independent.
        return mmh3.hash(struct.pack("<d", value))


class BucketDecimalTransform(BaseBucketTransform):
    """Transforms a value of DecimalType into a bucket partition value.

    Example:
        >>> transform = BucketDecimalTransform(DecimalType(9, 2), 100)
        >>> transform.apply(Decimal("14.20"))
        59
    """

    def can_transform(self, target: IcebergType) -> bool:
        return isinstance(target, DecimalType)

    def hash(self, value: Decimal) -> int:
        return mmh3.hash(transform_util.decimal_to_bytes(value))


class BucketStringTransform(BaseBucketTransform):
    """Transforms a value of StringType into a bucket partition value.

    Example:
        >>> transform = BucketStringTransform(StringType(), 100)
        >>> transform.apply("iceberg")
        89
    """

    def can_transform(self, target: IcebergType) -> bool:
        return isinstance(target, StringType)

    def hash(self, value: str) -> int:
        return mmh3.hash(value)


class BucketFixedTransform(BaseBucketTransform):
    """Transforms a value of FixedType into a bucket partition value.

    Example:
        >>> transform = BucketFixedTransform(FixedType(3), 128)
        >>> transform.apply(b"foo")
        32
    """

    def can_transform(self, target: IcebergType) -> bool:
        return isinstance(target, FixedType)

    def hash(self, value: bytearray) -> int:
        return mmh3.hash(value)


class BucketBinaryTransform(BaseBucketTransform):
    r"""Transforms a value of BinaryType into a bucket partition value.

    Example:
        >>> transform = BucketBinaryTransform(BinaryType(), 128)
        >>> transform.apply(b"\x00\x01\x02\x03")
        57
    """

    def can_transform(self, target: IcebergType) -> bool:
        return isinstance(target, BinaryType)

    def hash(self, value: bytes) -> int:
        return mmh3.hash(value)


class BucketUUIDTransform(BaseBucketTransform):
    """Transforms a value of UUIDType into a bucket partition value.

    Example:
        >>> transform = BucketUUIDTransform(UUIDType(), 100)
        >>> transform.apply(UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"))
        40
    """

    def can_transform(self, target: IcebergType) -> bool:
        return isinstance(target, UUIDType)

    def hash(self, value: UUID) -> int:
        # UUID.bytes is the 16-byte big-endian encoding, i.e. the same bytes as packing
        # the high and low 64-bit halves with ">QQ".
        return mmh3.hash(value.bytes)


class DateTimeTransform(Transform):
    """Base transform class for transforms of DateType, TimestampType, and TimestamptzType.

    Args:
        source_type (Type): An Iceberg Type of DateType, TimestampType or TimestamptzType
        name (str): The granularity name: "year", "month", "day" or "hour" ("hour" is not valid for DateType)

    Raises:
        ValueError: If the source type cannot be partitioned by the named granularity
    """

    class Granularity(Enum):
        def __init__(self, order: int, result_type: IcebergType, human_string: Callable[[int], str]):
            self.order = order
            self.result_type = result_type
            self.human_string = human_string

        # A smaller `order` means a finer granularity.
        YEAR = 3, IntegerType(), transform_util.human_year
        MONTH = 2, IntegerType(), transform_util.human_month
        DAY = 1, DateType(), transform_util.human_day
        HOUR = 0, IntegerType(), transform_util.human_hour

    # Date values are not mapped to HOUR, so an "hour" transform on a DateType is rejected.
    _DATE_APPLY_FUNCS = {
        Granularity.YEAR: transform_util.years_for_days,
        Granularity.MONTH: transform_util.months_for_days,
        Granularity.DAY: lambda d: d,
    }

    _TIMESTAMP_APPLY_FUNCS = {
        Granularity.YEAR: transform_util.years_for_ts,
        Granularity.MONTH: transform_util.months_for_ts,
        Granularity.DAY: transform_util.days_for_ts,
        Granularity.HOUR: transform_util.hours_for_ts,
    }

    def __init__(self, source_type: IcebergType, name: str):
        super().__init__(name, f"transforms.{name}(source_type={repr(source_type)})")

        self._type = source_type

        if isinstance(source_type, DateType):
            apply_funcs = DateTimeTransform._DATE_APPLY_FUNCS
        elif type(source_type) in {TimestampType, TimestamptzType}:
            apply_funcs = DateTimeTransform._TIMESTAMP_APPLY_FUNCS
        else:
            apply_funcs = {}

        # An unknown name yields None, which is never a key of the function tables.
        granularity = DateTimeTransform.Granularity.__members__.get(name.upper())
        if granularity not in apply_funcs:
            raise ValueError(f"Cannot partition type {source_type} by {name}")

        self._granularity = granularity
        self._apply = apply_funcs[granularity]

    # NOTE: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other):
        if type(self) is type(other):
            return self._type == other._type and self._granularity == other._granularity
        return False

    def can_transform(self, target: IcebergType) -> bool:
        if isinstance(self._type, DateType):
            return isinstance(target, DateType)
        # self._type is either TimestampType or TimestamptzType; only timestamp types
        # are accepted (previously every non-date type, e.g. StringType, was accepted).
        return type(target) in {TimestampType, TimestamptzType}

    def apply(self, value: Optional[int]) -> Optional[int]:
        """Return the partition ordinal for `value`, or None when `value` is None."""
        if value is None:
            return None
        return self._apply(value)

    def result_type(self, source_type: IcebergType) -> IcebergType:
        return self._granularity.result_type

    def preserves_order(self) -> bool:
        return True

    def satisfies_order_of(self, other: Transform) -> bool:
        if self == other:
            return True

        if isinstance(other, DateTimeTransform):
            # A finer (or equal) granularity satisfies the order of a coarser one.
            return self._granularity.order <= other._granularity.order

        return False

    def to_human_string(self, value) -> str:
        if value is None:
            return "null"
        return self._granularity.human_string(value)

    def dedup_name(self) -> str:
        return "time"


class IdentityTransform(Transform):
    """A transform that returns the source value unchanged.

    Args:
        source_type (Type): An Iceberg Type; it selects how `to_human_string` formats values
    """

    def __init__(self, source_type: IcebergType):
        super().__init__(
            "identity",
            f"transforms.identity(source_type={repr(source_type)})",
        )
        self._type = source_type

    def apply(self, value):
        return value

    def can_transform(self, target: IcebergType) -> bool:
        return target.is_primitive

    def preserves_order(self) -> bool:
        return True

    def satisfies_order_of(self, other: Transform) -> bool:
        return other.preserves_order()

    def to_human_string(self, value) -> str:
        if value is None:
            return "null"
        elif isinstance(self._type, DateType):
            return transform_util.human_day(value)
        elif isinstance(self._type, TimeType):
            return transform_util.human_time(value)
        elif isinstance(self._type, TimestampType):
            return transform_util.human_timestamp(value)
        elif isinstance(self._type, TimestamptzType):
            return transform_util.human_timestamptz(value)
        elif isinstance(self._type, (FixedType, BinaryType)):
            return transform_util.base64encode(value)
        else:
            return str(value)


class TruncateTransform(Transform):
    """A transform for truncating a value to a specified width.

    Args:
      source_type (Type): An Iceberg Truncatable Type of IntegerType, LongType, StringType, BinaryType or DecimalType
      width (int): The truncate width

    Raises:
      ValueError: If a type is provided that is incompatible with a Truncate transform
    """

    def __init__(self, source_type: IcebergType, width: int):
        if not isinstance(source_type, Truncatable):
            raise ValueError(f"Cannot truncate type: {source_type}")

        super().__init__(
            f"truncate[{width}]",
            f"transforms.truncate(source_type={repr(source_type)}, width={width})",
        )
        self._type = source_type
        self._width = width

    @property
    def width(self):
        return self._width

    def apply(self, value):
        if value is None:
            return None
        return self._type.truncate(value, self._width)

    def can_transform(self, target: IcebergType) -> bool:
        return self._type == target

    def preserves_order(self) -> bool:
        return True

    def satisfies_order_of(self, other: Transform) -> bool:
        if self == other:
            return True

        both_truncate_strings = (
            isinstance(self._type, StringType) and isinstance(other, TruncateTransform) and isinstance(other._type, StringType)
        )
        if both_truncate_strings:
            # A longer string prefix also orders values by any shorter prefix.
            return self._width >= other._width

        return False

    def to_human_string(self, value) -> str:
        if value is None:
            return "null"
        elif isinstance(self._type, BinaryType):
            return transform_util.base64encode(value)
        else:
            return str(value)


class UnknownTransform(Transform):
    """A transform that represents when an unknown transform is provided

    Args:
      source_type (Type): An Iceberg `Type`
      transform (str): A string name of a transform

    Raises:
      AttributeError: If the apply method is called.
    """

    def __init__(self, source_type: IcebergType, transform: str):
        super().__init__(
            transform,
            f"UnknownTransform(source_type={repr(source_type)}, transform={repr(transform)})",
        )
        self._type = source_type
        self._transform = transform

    def apply(self, value):
        raise AttributeError(f"Cannot apply unsupported transform: {self}")

    def can_transform(self, target: IcebergType) -> bool:
        return self._type == target

    def result_type(self, source: IcebergType) -> IcebergType:
        return StringType()


class VoidTransform(Transform):
    """A transform that always returns None"""

    # The single shared instance; __new__ returns it on every construction.
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(VoidTransform, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        super().__init__("void", "transforms.always_null()")

    def apply(self, value):
        return None

    def can_transform(self, target: IcebergType) -> bool:
        return True

    def to_human_string(self, value) -> str:
        return "null"


_HAS_WIDTH = re.compile(r"(\w+)\[(\d+)\]")


def from_string(source_type: IcebergType, transform: str) -> Transform:
    """Create a transform from its string form, e.g. "bucket[16]", "truncate[4]", "day".

    Args:
        source_type (Type): The Iceberg type the transform is applied to
        transform (str): The transform string; matching is case-insensitive

    Returns:
        The matching transform, or an `UnknownTransform` when the string is not recognized
        or the date/time granularity does not apply to `source_type`.

    Raises:
        ValueError: If "bucket[N]" or "truncate[W]" is requested for an unsupported source type
    """
    transform_lower = transform.lower()
    # fullmatch: the whole string must be `name[number]`, so trailing text is not ignored.
    match = _HAS_WIDTH.fullmatch(transform_lower)

    if match is not None:
        name = match.group(1)
        width = int(match.group(2))
        if name == "truncate":
            return truncate(source_type, width)
        if name == "bucket":
            # Use the factory so a type-specific subclass with a working `hash` is returned;
            # BaseBucketTransform itself cannot hash values.
            return bucket(source_type, width)

    if transform_lower == "identity":
        return identity(source_type)

    if transform_lower == "void":
        return always_null()

    try:
        return DateTimeTransform(source_type, transform_lower)
    except ValueError:
        pass  # not a date/time transform for this type: fall through to unknown

    return UnknownTransform(source_type, transform)


def identity(source_type: IcebergType) -> IdentityTransform:
    return IdentityTransform(source_type)


def year(source_type: IcebergType) -> Transform:
    return DateTimeTransform(source_type, "year")


def month(source_type: IcebergType) -> Transform:
    return DateTimeTransform(source_type, "month")


def day(source_type: IcebergType) -> Transform:
    return DateTimeTransform(source_type, "day")


def hour(source_type: IcebergType) -> Transform:
    return DateTimeTransform(source_type, "hour")


def bucket(source_type: IcebergType, num_buckets: int) -> BaseBucketTransform:
    """Create the bucket transform that matches `source_type`.

    Raises:
        ValueError: If `source_type` cannot be bucketed
    """
    if isinstance(source_type, IntegerType):
        return BucketIntegerTransform(source_type, num_buckets)
    elif isinstance(source_type, DecimalType):
        return BucketDecimalTransform(source_type, num_buckets)
    elif isinstance(source_type, DateType):
        return BucketIntegerTransform(source_type, num_buckets)
    elif type(source_type) in {LongType, TimeType, TimestampType, TimestamptzType}:
        return BucketLongTransform(source_type, num_buckets)
    elif isinstance(source_type, StringType):
        return BucketStringTransform(source_type, num_buckets)
    elif isinstance(source_type, BinaryType):
        return BucketBinaryTransform(source_type, num_buckets)
    elif isinstance(source_type, FixedType):
        return BucketFixedTransform(source_type, num_buckets)
    elif isinstance(source_type, UUIDType):
        return BucketUUIDTransform(source_type, num_buckets)
    else:
        raise ValueError(f"Cannot bucket by type: {source_type}")


def truncate(source_type: IcebergType, width: int) -> TruncateTransform:
    return TruncateTransform(source_type, width)


def always_null() -> Transform:
    return VoidTransform()
def __new__(cls, *args, **kwargs): return cls._instance +class Truncatable(ABC, Generic[T]): + @abstractmethod + def truncate(self, value: T, width: int) -> T: + ... + + class IcebergType: """Base type for all Iceberg Types""" @@ -92,7 +104,7 @@ def length(self) -> int: return self._length -class DecimalType(PrimitiveType): +class DecimalType(PrimitiveType, Truncatable[Decimal]): """A fixed data type in Iceberg. Example: @@ -126,6 +138,10 @@ def precision(self) -> int: def scale(self) -> int: return self._scale + def truncate(self, value: Decimal, width: int) -> Decimal: + """Truncate a given decimal value into a given width.""" + return transform_util.truncate_decimal(value, width) + class NestedField(IcebergType): """Represents a field of a struct, a map key, a map value, or a list element. @@ -340,7 +356,7 @@ def __init__(self): super().__init__("boolean", "BooleanType()") -class IntegerType(PrimitiveType, Singleton): +class IntegerType(PrimitiveType, Singleton, Truncatable[int]): """An Integer data type in Iceberg can be represented using an instance of this class. Integers in Iceberg are 32-bit signed and can be promoted to Longs. @@ -364,8 +380,12 @@ def __init__(self): if not self._initialized: super().__init__("int", "IntegerType()") + def truncate(self, value: int, width: int) -> int: + """Truncate a given int value into a given width.""" + return value - value % width -class LongType(PrimitiveType, Singleton): + +class LongType(PrimitiveType, Singleton, Truncatable[int]): """A Long data type in Iceberg can be represented using an instance of this class. Longs in Iceberg are 64-bit signed integers. 
@@ -389,6 +409,10 @@ def __init__(self): if not self._initialized: super().__init__("long", "LongType()") + def truncate(self, value: int, width: int) -> int: + """Truncate a given long value into a given width.""" + return value - value % width + class FloatType(PrimitiveType, Singleton): """A Float data type in Iceberg can be represented using an instance of this class. Floats in Iceberg are @@ -480,7 +504,7 @@ def __init__(self): super().__init__("timestamptz", "TimestamptzType()") -class StringType(PrimitiveType, Singleton): +class StringType(PrimitiveType, Singleton, Truncatable[str]): """A String data type in Iceberg can be represented using an instance of this class. Strings in Iceberg are arbitrary-length character sequences and are encoded with UTF-8. @@ -494,6 +518,10 @@ def __init__(self): if not self._initialized: super().__init__("string", "StringType()") + def truncate(self, value: str, width: int) -> str: + """Truncate a given string to a given width.""" + return value[0 : min(width, len(value))] + class UUIDType(PrimitiveType, Singleton): """A UUID data type in Iceberg can be represented using an instance of this class. UUIDs in @@ -510,7 +538,7 @@ def __init__(self): super().__init__("uuid", "UUIDType()") -class BinaryType(PrimitiveType, Singleton): +class BinaryType(PrimitiveType, Singleton, Truncatable[bytearray]): """A Binary data type in Iceberg can be represented using an instance of this class. Binarys in Iceberg are arbitrary-length byte arrays. 
@@ -523,3 +551,7 @@ class BinaryType(PrimitiveType, Singleton): def __init__(self): if not self._initialized: super().__init__("binary", "BinaryType()") + + def truncate(self, value: bytearray, width: int) -> bytearray: + """Truncate a given binary bytes into a given width.""" + return value[0 : min(width, len(value))] diff --git a/python/src/iceberg/utils/transform_util.py b/python/src/iceberg/utils/transform_util.py new file mode 100644 index 000000000000..98a3d09d5b00 --- /dev/null +++ b/python/src/iceberg/utils/transform_util.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import base64
from datetime import datetime, timedelta, timezone
from decimal import Decimal

# Naive datetime for 1970-01-01T00:00:00; all ordinals below are offsets from it.
_EPOCH = datetime(1970, 1, 1)
_EPOCH_YEAR = _EPOCH.year
_EPOCH_MONTH = _EPOCH.month
_EPOCH_DAY = _EPOCH.day

_MICROS_PER_HOUR = 3_600_000_000
_MICROS_PER_DAY = 24 * _MICROS_PER_HOUR


def human_year(year_ordinal: int) -> str:
    """Format a count of years since 1970 as "YYYY"."""
    return "{0:0=4d}".format(_EPOCH_YEAR + year_ordinal)


def human_month(month_ordinal: int) -> str:
    """Format a count of months since 1970-01 as "YYYY-MM"."""
    # divmod floors, so negative ordinals land in the correct earlier year
    # (-1 -> 1969-12); truncating division would report 1970-12.
    year_offset, month_index = divmod(month_ordinal, 12)
    return "{0:0=4d}-{1:0=2d}".format(_EPOCH_YEAR + year_offset, 1 + month_index)


def human_day(day_ordinal: int) -> str:
    """Format a count of days since 1970-01-01 as "YYYY-MM-DD"."""
    time = _EPOCH + timedelta(days=day_ordinal)
    return "{0:0=4d}-{1:0=2d}-{2:0=2d}".format(time.year, time.month, time.day)


def human_hour(hour_ordinal: int) -> str:
    """Format a count of hours since 1970-01-01T00 as "YYYY-MM-DD-HH"."""
    time = _EPOCH + timedelta(hours=hour_ordinal)
    return "{0:0=4d}-{1:0=2d}-{2:0=2d}-{3:0=2d}".format(time.year, time.month, time.day, time.hour)


def human_time(micros_from_midnight: int) -> str:
    """Format microseconds from midnight as a time of day, e.g. "10:12:55.038194"."""
    day = _EPOCH + timedelta(microseconds=micros_from_midnight)
    return f"{day.time()}"


def human_timestamptz(timestamp_micros: int) -> str:
    """Format microseconds since the epoch as a UTC timestamp with a trailing "Z"."""
    day = _EPOCH + timedelta(microseconds=timestamp_micros)
    # The format string hard-codes the "Z" suffix and uses no timezone directive, so
    # attaching the stdlib UTC tzinfo yields the same text as localizing with pytz.
    return day.replace(tzinfo=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")


def human_timestamp(timestamp_micros: int) -> str:
    """Format microseconds since the epoch as an ISO-8601 timestamp without a zone."""
    day = _EPOCH + timedelta(microseconds=timestamp_micros)
    return day.isoformat()


def base64encode(buffer: bytes) -> str:
    """Return the base64 encoding of `buffer` as text."""
    return base64.b64encode(buffer).decode("ISO-8859-1")


def _unscale_decimal(decimal_value: Decimal) -> int:
    """Return the unscaled integer of a decimal: its signed digits without the exponent."""
    value_tuple = decimal_value.as_tuple()
    return int(("-" if value_tuple.sign else "") + "".join([str(d) for d in value_tuple.digits]))


def decimal_to_bytes(value: Decimal) -> bytes:
    """Return the unscaled value of `value` as minimal big-endian two's-complement bytes.

    Zero encodes as a single zero byte, and negative values are supported.
    """
    unscaled_value = _unscale_decimal(value)
    # Two's complement needs one sign bit on top of the magnitude bits. For a negative
    # number the magnitude bits are those of its bitwise complement (-128 fits in one byte).
    magnitude = unscaled_value if unscaled_value >= 0 else ~unscaled_value
    number_of_bytes = magnitude.bit_length() // 8 + 1
    return unscaled_value.to_bytes(length=number_of_bytes, byteorder="big", signed=True)


def truncate_decimal(value: Decimal, width: int) -> Decimal:
    """Round the unscaled value of `value` down to a multiple of `width`, keeping its exponent."""
    unscaled_value = _unscale_decimal(value)
    # Python's % takes the sign of the divisor, so this floors toward negative infinity.
    applied_value = unscaled_value - unscaled_value % width
    return Decimal(f"{applied_value}e{value.as_tuple().exponent}")


def hours_for_ts(timestamp: int) -> int:
    """Return whole hours from the epoch for a timestamp in microseconds (floored)."""
    # Integer floor division is exact and keeps pre-epoch timestamps in the earlier hour.
    return timestamp // _MICROS_PER_HOUR


def days_for_ts(timestamp: int) -> int:
    """Return whole days from the epoch for a timestamp in microseconds (floored)."""
    return timestamp // _MICROS_PER_DAY


def months_for_days(days: int) -> int:
    """Return months from 1970-01 for a date given as days from the epoch."""
    dt = _EPOCH + timedelta(days=days)
    return (dt.year - _EPOCH_YEAR) * 12 + (dt.month - _EPOCH_MONTH)


def months_for_ts(timestamp: int) -> int:
    """Return months from 1970-01 for a timestamp in microseconds."""
    dt = _EPOCH + timedelta(microseconds=timestamp)
    return (dt.year - _EPOCH_YEAR) * 12 + (dt.month - _EPOCH_MONTH)


def years_for_days(days: int) -> int:
    """Return years from 1970 for a date given as days from the epoch."""
    dt = _EPOCH + timedelta(days=days)
    return dt.year - _EPOCH_YEAR


def years_for_ts(timestamp: int) -> int:
    """Return years from 1970 for a timestamp in microseconds."""
    dt = _EPOCH + timedelta(microseconds=timestamp)
    return dt.year - _EPOCH_YEAR
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime
from decimal import Decimal, getcontext
from uuid import UUID

import pytest

from iceberg import transforms
from iceberg.transforms import BucketDoubleTransform, UnknownTransform
from iceberg.types import (
    BinaryType,
    BooleanType,
    DateType,
    DecimalType,
    DoubleType,
    FixedType,
    FloatType,
    IntegerType,
    ListType,
    LongType,
    MapType,
    NestedField,
    StringType,
    StructType,
    TimestampType,
    TimestamptzType,
    TimeType,
    UUIDType,
)


@pytest.mark.parametrize(
    "test_input,test_type,expected",
    [
        (1, IntegerType(), 1392991556),
        (34, IntegerType(), 2017239379),
        (34, LongType(), 2017239379),
        (17486, DateType(), -653330422),
        (81068000000, TimeType(), -662762989),
        (
            int(datetime.fromisoformat("2017-11-16T22:31:08+00:00").timestamp() * 1000000),
            TimestampType(),
            -2047944441,
        ),
        (
            int(datetime.fromisoformat("2017-11-16T14:31:08-08:00").timestamp() * 1000000),
            TimestamptzType(),
            -2047944441,
        ),
        (b"\x00\x01\x02\x03", BinaryType(), -188683207),
        (b"\x00\x01\x02\x03", FixedType(4), -188683207),
        ("iceberg", StringType(), 1210000089),
        (UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), UUIDType(), 1488055340),
    ],
)
def test_bucket_hash_values(test_input, test_type, expected):
    assert transforms.bucket(test_type, 8).hash(test_input) == expected


@pytest.mark.parametrize(
    "test_input,test_type,expected",
    [
        (1.0, FloatType(), -142385009),
        (1.0, DoubleType(), -142385009),
    ],
)
def test_spec_double_float_hash(test_input, test_type, expected):
    assert BucketDoubleTransform(test_type, 8).hash(test_input) == expected


@pytest.mark.parametrize(
    "test_input,test_type,scale_factor,expected_hash,expected",
    [
        (Decimal("14.20"), DecimalType(9, 2), Decimal(10) ** -2, -500754589, 59),
        (
            Decimal("137302769811943318102518958871258.37580"),
            DecimalType(38, 5),
            Decimal(10) ** -5,
            -32334285,
            63,
        ),
    ],
)
def test_decimal_bucket(test_input, test_type, scale_factor, expected_hash, expected):
    # quantize needs enough precision to hold the 38-digit value.
    getcontext().prec = 38
    assert transforms.bucket(test_type, 100).hash(test_input.quantize(scale_factor)) == expected_hash
    assert transforms.bucket(test_type, 100).apply(test_input.quantize(scale_factor)) == expected


@pytest.mark.parametrize(
    "bucket,value,expected",
    [
        (transforms.bucket(IntegerType(), 100), 34, 79),
        (transforms.bucket(LongType(), 100), 34, 79),
        (transforms.bucket(DateType(), 100), 17486, 26),
        (transforms.bucket(TimeType(), 100), 81068000000, 59),
        (transforms.bucket(TimestampType(), 100), 1510871468000000, 7),
        (transforms.bucket(DecimalType(9, 2), 100), Decimal("14.20"), 59),
        (transforms.bucket(StringType(), 100), "iceberg", 89),
        (
            transforms.bucket(UUIDType(), 100),
            UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"),
            40,
        ),
        (transforms.bucket(FixedType(3), 128), b"foo", 32),
        (transforms.bucket(BinaryType(), 128), b"\x00\x01\x02\x03", 57),
    ],
)
def test_buckets(bucket, value, expected):
    assert bucket.apply(value) == expected


@pytest.mark.parametrize(
    "date,time_transform_name,expected",
    [
        (47, "year", "2017"),
        (575, "month", "2017-12"),
        (17501, "day", "2017-12-01"),
    ],
)
def test_time_to_human_string(date, time_transform_name, expected):
    assert getattr(transforms, time_transform_name)(DateType()).to_human_string(date) == expected


# The timestamp and timestamptz variants are parametrized on the source type. They were
# previously separate functions sharing one name, so only the last definition of each
# was collected and the TimestampType cases never ran.
@pytest.mark.parametrize("type_var", [TimestampType(), TimestamptzType()])
@pytest.mark.parametrize(
    "timestamp,time_transform_name,expected",
    [
        (47, "year", "2017"),
        (575, "month", "2017-12"),
        (17501, "day", "2017-12-01"),
        (420042, "hour", "2017-12-01-18"),
    ],
)
def test_ts_to_human_string(type_var, timestamp, time_transform_name, expected):
    assert getattr(transforms, time_transform_name)(type_var).to_human_string(timestamp) == expected


@pytest.mark.parametrize("type_var", [TimestampType(), TimestamptzType()])
@pytest.mark.parametrize("time_transform_name", ["year", "month", "day", "hour"])
def test_null_human_string(type_var, time_transform_name):
    assert getattr(transforms, time_transform_name)(type_var).to_human_string(None) == "null"


@pytest.mark.parametrize(
    "type_var,value,expected",
    [
        (LongType(), None, "null"),
        (DateType(), 17501, "2017-12-01"),
        (TimeType(), 36775038194, "10:12:55.038194"),
        (TimestamptzType(), 1512151975038194, "2017-12-01T18:12:55.038194Z"),
        (TimestampType(), 1512151975038194, "2017-12-01T18:12:55.038194"),
        (LongType(), -1234567890000, "-1234567890000"),
        (StringType(), "a/b/c=d", "a/b/c=d"),
        (DecimalType(9, 2), Decimal("-1.50"), "-1.50"),
        (FixedType(100), b"foo", "Zm9v"),
    ],
)
def test_identity_human_string(type_var, value, expected):
    identity = transforms.identity(type_var)
    assert identity.to_human_string(value) == expected
+@pytest.mark.parametrize("type_var", [IntegerType(), LongType()])
+@pytest.mark.parametrize(
+    "input_var,expected",
+    [(1, 0), (5, 0), (9, 0), (10, 10), (11, 10), (-1, -10), (-10, -10), (-12, -20)],
+)
+def test_truncate_integer(type_var, input_var, expected):
+    """Truncate integer and long values with width 10.
+
+    The expected values are multiples of 10; negative inputs are expected to
+    move to the next lower multiple (-1 -> -10, -12 -> -20).
+    """
+    trunc = transforms.truncate(type_var, 10)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize(
+    "input_var,expected",
+    [
+        (Decimal("12.34"), Decimal("12.30")),
+        (Decimal("12.30"), Decimal("12.30")),
+        (Decimal("12.29"), Decimal("12.20")),
+        (Decimal("0.05"), Decimal("0.00")),
+        (Decimal("-0.05"), Decimal("-0.10")),
+    ],
+)
+def test_truncate_decimal(input_var, expected):
+    """Truncate DecimalType(9, 2) values with width 10.
+
+    The expected values move in steps of 0.10, and -0.05 is expected to
+    become -0.10 rather than 0.00.
+    """
+    trunc = transforms.truncate(DecimalType(9, 2), 10)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize("input_var,expected", [("abcdefg", "abcde"), ("abc", "abc")])
+def test_truncate_string(input_var, expected):
+    """Truncate strings with width 5: longer input is cut to five characters, shorter input is unchanged."""
+    trunc = transforms.truncate(StringType(), 5)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize(
+    "type_var",
+    [
+        BinaryType(),
+        DateType(),
+        DecimalType(8, 5),
+        FixedType(8),
+        IntegerType(),
+        LongType(),
+        StringType(),
+        TimestampType(),
+        TimestamptzType(),
+        TimeType(),
+        UUIDType(),
+    ],
+)
+def test_bucket_method(type_var):
+    """Check the API of an 8-bucket transform for every parametrized source type."""
+    bucket_transform = transforms.bucket(type_var, 8)
+    # repr() must evaluate back to a transform with the same str(). eval() looks
+    # names up in this function's locals and the module globals, so every name
+    # used by the repr string has to be importable in this test module.
+    assert str(bucket_transform) == str(eval(repr(bucket_transform)))
+    assert bucket_transform.can_transform(type_var)
+    assert bucket_transform.result_type(type_var) == IntegerType()
+    assert bucket_transform.num_buckets == 8
+    # None is expected to pass through apply() unchanged.
+    assert bucket_transform.apply(None) is None
+    assert bucket_transform.to_human_string("test") == "test"
+
+
+@pytest.mark.parametrize(
+    "type_var,value,expected_human_str,expected",
+    [
+        (BinaryType(), b"foo", "Zm9v", b"f"),
+        (DecimalType(8, 5), Decimal("14.21"), "14.21", Decimal("14.21")),
+        (IntegerType(), 123, "123", 123),
+        (LongType(), 123, "123", 123),
+        (StringType(), "foo", "foo", "f"),
+    ],
+)
+def test_truncate_method(type_var, value, expected_human_str, expected):
+    """Check the API of a width-1 truncate transform for each supported source type.
+
+    Args:
+        type_var: source type the transform is built for
+        value: sample value passed to to_human_string() and apply()
+        expected_human_str: expected to_human_string(value) result
+        expected: expected apply(value) result
+    """
+    truncate_transform = transforms.truncate(type_var, 1)
+    # repr() must evaluate back to a transform with the same str().
+    assert str(truncate_transform) == str(eval(repr(truncate_transform)))
+    assert truncate_transform.can_transform(type_var)
+    # Truncation is expected to keep the source type.
+    assert truncate_transform.result_type(type_var) == type_var
+    assert truncate_transform.to_human_string(value) == expected_human_str
+    assert truncate_transform.apply(value) == expected
+    assert truncate_transform.to_human_string(None) == "null"
+    assert truncate_transform.width == 1
+    assert truncate_transform.apply(None) is None
+    assert truncate_transform.preserves_order()
+    assert truncate_transform.satisfies_order_of(truncate_transform)
+
+
+@pytest.mark.parametrize(
+    "type_var",
+    [
+        DateType(),
+        TimestampType(),
+        TimestamptzType(),
+    ],
+)
+def test_time_methods(type_var):
+    """Check the year, month and day transforms for date and timestamp source types.
+
+    Covers the repr round-trip, can_transform(), preserves_order(), the result
+    type (IntegerType for year and month, DateType for day) and the shared
+    dedup name "time".
+    """
+    assert transforms.year(type_var) == eval(repr(transforms.year(type_var)))
+    assert transforms.month(type_var) == eval(repr(transforms.month(type_var)))
+    assert transforms.day(type_var) == eval(repr(transforms.day(type_var)))
+    assert transforms.year(type_var).can_transform(type_var)
+    assert transforms.month(type_var).can_transform(type_var)
+    assert transforms.day(type_var).can_transform(type_var)
+    assert transforms.year(type_var).preserves_order()
+    assert transforms.month(type_var).preserves_order()
+    assert transforms.day(type_var).preserves_order()
+    assert transforms.year(type_var).result_type(type_var) == IntegerType()
+    assert transforms.month(type_var).result_type(type_var) == IntegerType()
+    assert transforms.day(type_var).result_type(type_var) == DateType()
+    assert transforms.year(type_var).dedup_name() == "time"
+    assert transforms.month(type_var).dedup_name() == "time"
+    assert transforms.day(type_var).dedup_name() == "time"
+
+
+# The date input 17501 and the timestamp input 1512151975038194 share the same
+# expected year (47), month (575) and day (17501) values.
+# NOTE(review): presumably days and microseconds since the epoch - confirm
+# against iceberg.utils.transform_util.
+@pytest.mark.parametrize(
+    "transform,value,expected",
+    [
+        (transforms.day(DateType()), 17501, 17501),
+        (transforms.month(DateType()), 17501, 575),
+        (transforms.year(DateType()), 17501, 47),
+        (transforms.year(TimestampType()), 1512151975038194, 47),
+        (transforms.month(TimestamptzType()), 1512151975038194, 575),
+        (transforms.day(TimestampType()), 1512151975038194, 17501),
+    ],
+)
+def test_time_apply_method(transform, value, expected):
+    """Apply each time transform to a sample value and compare with the expected result."""
+    assert transform.apply(value) == expected
+
+
+@pytest.mark.parametrize(
+    "type_var",
+    [
+        TimestampType(),
+        TimestamptzType(),
+    ],
+)
+def test_hour_method(type_var):
+    """Check the hour transform for both timestamp types.
+
+    Covers the repr round-trip, can_transform(), the IntegerType result type,
+    one sample apply() value and the dedup name "time".
+    """
+    assert transforms.hour(type_var) == eval(repr(transforms.hour(type_var)))
+    assert transforms.hour(type_var).can_transform(type_var)
+    assert transforms.hour(type_var).result_type(type_var) == IntegerType()
+    assert transforms.hour(type_var).apply(1512151975038194) == 420042
+    assert transforms.hour(type_var).dedup_name() == "time"
+
+
+@pytest.mark.parametrize(
+    "type_var",
+    [
+        BinaryType(),
+        BooleanType(),
+        DateType(),
+        DecimalType(8, 2),
+        DoubleType(),
+        FixedType(16),
+        FloatType(),
+        IntegerType(),
+        LongType(),
+        StringType(),
+        TimestampType(),
+        TimestamptzType(),
+        TimeType(),
+        UUIDType(),
+    ],
+)
+def test_identity_method(type_var):
+    """Check the identity transform for every parametrized primitive type."""
+    identity_transform = transforms.identity(type_var)
+    # repr() must evaluate back to a transform with the same str().
+    assert str(identity_transform) == str(eval(repr(identity_transform)))
+    assert identity_transform.can_transform(type_var)
+    assert identity_transform.result_type(type_var) == type_var
+    # The same string sample is used for every type_var; apply() is expected
+    # to return it unchanged.
+    assert identity_transform.apply("test") == "test"
+
+
+# NOTE(review): in the last StructType below, the sibling fields with ids 2 and 3
+# are both named "required_field" - confirm the duplicate name is intentional.
+@pytest.mark.parametrize(
+    "type_var",
+    [
+        ListType(
+            1,
+            StructType(
+                NestedField(2, "optional_field", DecimalType(8, 2), is_optional=True),
+                NestedField(3, "required_field", LongType(), is_optional=False),
+            ),
+            False,
+        ),
+        MapType(1, DoubleType(), 2, UUIDType(), False),
+        StructType(
+            NestedField(1, "optional_field", IntegerType(), is_optional=True),
+            NestedField(2, "required_field", FixedType(5), is_optional=False),
+            NestedField(
+                3,
+                "required_field",
+                StructType(
+                    NestedField(4, "optional_field", DecimalType(8, 2), is_optional=True),
+                    NestedField(5, "required_field", LongType(), is_optional=False),
+                ),
+                is_optional=False,
+            ),
+        ),
+    ],
+)
+def test_identity_nested_type(type_var):
+    """Identity on a list, map or struct type: the repr still round-trips, but can_transform() is False."""
+    identity_transform = transforms.identity(type_var)
+    assert str(identity_transform) == str(eval(repr(identity_transform)))
+    assert not identity_transform.can_transform(type_var)
+
+
+def test_void_transform():
+    """Check the always-null ("void") transform.
+
+    apply() returns None, the result type keeps the source type's class, the
+    transform does not preserve order, it only satisfies the order of another
+    always-null transform, and every value is rendered as "null".
+    """
+    void_transform = transforms.always_null()
+    assert void_transform == eval(repr(void_transform))
+    assert void_transform.apply("test") is None
+    assert void_transform.can_transform(BooleanType())
+    assert isinstance(void_transform.result_type(BooleanType()), BooleanType)
+    assert not void_transform.preserves_order()
+    assert void_transform.satisfies_order_of(transforms.always_null())
+    assert not void_transform.satisfies_order_of(transforms.year(DateType()))
+    assert void_transform.to_human_string("test") == "null"
+    assert void_transform.dedup_name() == "void"
+
+
+def test_unknown_transform():
+    """Check UnknownTransform built for FixedType(8) with the name "unknown".
+
+    apply() is expected to raise AttributeError, can_transform() accepts only
+    the FixedType(8) it was built with, and result_type() is a StringType.
+    """
+    unknown_transform = UnknownTransform(FixedType(8), "unknown")
+    assert str(unknown_transform) == str(eval(repr(unknown_transform)))
+    with pytest.raises(AttributeError):
+        unknown_transform.apply("test")
+    assert unknown_transform.can_transform(FixedType(8))
+    assert not unknown_transform.can_transform(FixedType(5))
+    assert isinstance(unknown_transform.result_type(BooleanType()), StringType)
+
+
+# The (TimeType(), "day") row expects an UnknownTransform instead of a day
+# transform.
+@pytest.mark.parametrize(
+    "type_var,transform,expected",
+    [
+        (BinaryType(), "bucket[100]", transforms.bucket(BinaryType(), 100)),
+        (BooleanType(), "identity", transforms.identity(BooleanType())),
+        (DateType(), "year", transforms.year(DateType())),
+        (DecimalType(8, 2), "truncate[5]", transforms.truncate(DecimalType(8, 2), 5)),
+        (DoubleType(), "identity", transforms.identity(DoubleType())),
+        (FixedType(16), "bucket[32]", transforms.bucket(FixedType(16), 32)),
+        (FloatType(), "identity", transforms.identity(FloatType())),
+        (IntegerType(), "void", transforms.always_null()),
+        (LongType(), "bucket[16]", transforms.bucket(LongType(), 16)),
+        (StringType(), "truncate[8]", transforms.truncate(StringType(), 8)),
+        (TimestampType(), "month", transforms.month(TimestampType())),
+        (TimestamptzType(), "hour", transforms.hour(TimestamptzType())),
+        (TimeType(), "day", UnknownTransform(TimeType(), "day")),
+        (UUIDType(), "bucket[16]", transforms.bucket(UUIDType(), 16)),
+    ],
+)
+def test_from_string(type_var, transform, expected):
+    """Parse a transform string for a source type and compare it with the expected transform.
+
+    The parsed transform must have the same repr() as the expected one, and the
+    expected transform's str() must equal the input string.
+    """
+    assert repr(transforms.from_string(type_var, transform)) == repr(expected)
+    assert transform == str(expected)
+
+
+# Each row is (transform, other_transform, expected result of
+# transform.satisfies_order_of(other_transform)). In these rows hour and day
+# satisfy the order of month while year does not, and a truncate transform
+# satisfies another string truncate only when its width is at least as large.
+@pytest.mark.parametrize(
+    "transform,other_transform,expected",
+    [
+        (transforms.identity(BooleanType()), transforms.identity(IntegerType()), True),
+        (transforms.identity(BooleanType()), transforms.always_null(), False),
+        (transforms.year(DateType()), transforms.year(DateType()), True),
+        (transforms.year(DateType()), transforms.month(DateType()), False),
+        (transforms.year(DateType()), transforms.day(DateType()), False),
+        (transforms.year(DateType()), transforms.hour(TimestampType()), False),
+        (transforms.hour(TimestampType()), transforms.month(DateType()), True),
+        (transforms.day(TimestamptzType()), transforms.month(DateType()), True),
+        (transforms.day(TimestamptzType()), transforms.always_null(), False),
+        (
+            transforms.truncate(StringType(), 8),
+            transforms.truncate(StringType(), 16),
+            False,
+        ),
+        (
+            transforms.truncate(StringType(), 16),
+            transforms.truncate(StringType(), 16),
+            True,
+        ),
+        (
+            transforms.truncate(StringType(), 32),
+            transforms.truncate(StringType(), 16),
+            True,
+        ),
+        (
+            transforms.truncate(StringType(), 16),
+            transforms.truncate(IntegerType(), 8),
+            False,
+        ),
+    ],
+)
+def test_satisfies_order_of(transform, other_transform, expected):
+    """Compare satisfies_order_of() with the expected boolean for each pair of transforms."""
+    assert transform.satisfies_order_of(other_transform) == expected
+
+
+def test_invalid_cases():
+    """Each factory below is expected to raise ValueError for the given source type."""
+    with pytest.raises(ValueError):
+        transforms.hour(DateType())
+    with pytest.raises(ValueError):
+        transforms.day(IntegerType())
+    with pytest.raises(ValueError):
+        transforms.month(BinaryType())
+    with pytest.raises(ValueError):
+        transforms.year(UUIDType())
+    with pytest.raises(ValueError):
+        transforms.bucket(BooleanType(), 8)
+    with pytest.raises(ValueError):
+        transforms.truncate(UUIDType(), 8)
diff --git a/python/tox.ini b/python/tox.ini
index 214df438c9b5..d287d27f6ed2 100644
--- a/python/tox.ini
+++ b/python/tox.ini
@@ -66,7 +66,7 @@ commands =
 deps =
     mypy
 commands =
-    mypy --no-implicit-optional --config tox.ini src
+    mypy --install-types --non-interactive --no-implicit-optional --config tox.ini src
 
 [testenv:docs]
 basepython = python3