diff --git a/python/src/iceberg/transforms.py b/python/src/iceberg/transforms.py
index a2fa8f67fd37..aafda5df8058 100644
--- a/python/src/iceberg/transforms.py
+++ b/python/src/iceberg/transforms.py
@@ -19,8 +19,13 @@
 import struct
 from abc import ABC, abstractmethod
 from decimal import Decimal
-from functools import singledispatchmethod
-from typing import Generic, Optional, TypeVar
+from functools import singledispatch
+from typing import (
+    Any,
+    Generic,
+    Optional,
+    TypeVar,
+)
 from uuid import UUID
 
 import mmh3  # type: ignore
@@ -40,7 +45,7 @@
     UUIDType,
 )
 from iceberg.utils import datetime
-from iceberg.utils.decimal import decimal_to_bytes
+from iceberg.utils.decimal import decimal_to_bytes, truncate_decimal
 from src.iceberg.utils.singleton import Singleton
 
 S = TypeVar("S")
@@ -267,39 +272,135 @@ def satisfies_order_of(self, other: Transform) -> bool:
         return other.preserves_order
 
     def to_human_string(self, value: Optional[S]) -> str:
-        return self._human_string(value)
+        return _human_string(value, self._type) if value is not None else "null"
 
-    @singledispatchmethod
-    def _human_string(self, value: Optional[S]) -> str:
-        return str(value) if value is not None else "null"
 
-    @_human_string.register(bytes)
-    def _(self, value: bytes) -> str:
-        return _base64encode(value)
+class TruncateTransform(Transform[S, S]):
+    """A transform for truncating a value to a specified width.
+    Args:
+        source_type (Type): An Iceberg Type of IntegerType, LongType, StringType, BinaryType or DecimalType
+        width (int): The truncate width, must be greater than 0
+    Raises:
+        ValueError: If the width is not greater than 0 (unsupported value types are rejected by apply)
+    """
+
+    def __init__(self, source_type: IcebergType, width: int):
+        # Raise explicitly rather than assert: assert statements are stripped when Python runs with -O
+        if width <= 0:
+            raise ValueError(f"width ({width}) should be greater than 0")
+        super().__init__(
+            f"truncate[{width}]",
+            f"transforms.truncate(source_type={repr(source_type)}, width={width})",
+        )
+        self._type = source_type
+        self._width = width
 
-    @_human_string.register(int)
-    def _(self, value: int) -> str:
-        return self._int_to_human_string(self._type, value)
+    @property
+    def width(self) -> int:
+        return self._width
 
-    @singledispatchmethod
-    def _int_to_human_string(self, _: IcebergType, value: int) -> str:
-        return str(value)
+    @property
+    def type(self) -> IcebergType:
+        return self._type
 
-    @_int_to_human_string.register(DateType)
-    def _(self, _: IcebergType, value: int) -> str:
-        return datetime.to_human_day(value)
+    def apply(self, value: Optional[S]) -> Optional[S]:
+        return _truncate_value(value, self._width) if value is not None else None
 
-    @_int_to_human_string.register(TimeType)
-    def _(self, _: IcebergType, value: int) -> str:
-        return datetime.to_human_time(value)
+    def can_transform(self, source: IcebergType) -> bool:
+        return self._type == source
+
+    def result_type(self, source: IcebergType) -> IcebergType:
+        return source
+
+    @property
+    def preserves_order(self) -> bool:
+        return True
+
+    def satisfies_order_of(self, other: Transform) -> bool:
+        if self == other:
+            return True
+        elif isinstance(self._type, StringType) and isinstance(other, TruncateTransform) and isinstance(other.type, StringType):
+            # Between two string truncates, this one satisfies the other's order when its width is at least as large
+            return self._width >= other.width
+
+        return False
+
+    def to_human_string(self, value: Optional[S]) -> str:
+        if value is None:
+            return "null"
+        elif isinstance(value, bytes):
+            return _base64encode(value)
+        else:
+            return str(value)
 
-    @_int_to_human_string.register(TimestampType)
-    def _(self, _: IcebergType, value: int) -> str:
-        return datetime.to_human_timestamp(value)
 
-    @_int_to_human_string.register(TimestamptzType)
-    def _(self, _: IcebergType, value: int) -> str:
-        return datetime.to_human_timestamptz(value)
+@singledispatch
+def _human_string(value: Any, _type: IcebergType) -> str:
+    return str(value)
+
+
+@_human_string.register(bytes)
+def _(value: bytes, _type: IcebergType) -> str:
+    return _base64encode(value)
+
+
+@_human_string.register(int)
+def _(value: int, _type: IcebergType) -> str:
+    return _int_to_human_string(_type, value)
+
+
+@singledispatch
+def _int_to_human_string(_type: IcebergType, value: int) -> str:
+    return str(value)
+
+
+@_int_to_human_string.register(DateType)
+def _(_type: IcebergType, value: int) -> str:
+    return datetime.to_human_day(value)
+
+
+@_int_to_human_string.register(TimeType)
+def _(_type: IcebergType, value: int) -> str:
+    return datetime.to_human_time(value)
+
+
+@_int_to_human_string.register(TimestampType)
+def _(_type: IcebergType, value: int) -> str:
+    return datetime.to_human_timestamp(value)
+
+
+@_int_to_human_string.register(TimestamptzType)
+def _(_type: IcebergType, value: int) -> str:
+    return datetime.to_human_timestamptz(value)
+
+
+@singledispatch
+def _truncate_value(value: Any, _width: int) -> S:
+    raise ValueError(f"Cannot truncate value: {value}")
+
+
+@_truncate_value.register(int)
+def _(value: int, _width: int) -> int:
+    """Truncate a given int down to a multiple of the width (negative values round toward negative infinity)."""
+    return value - value % _width
+
+
+@_truncate_value.register(str)
+def _(value: str, _width: int) -> str:
+    """Truncate a given string to a given width."""
+    return value[:_width]
+
+
+@_truncate_value.register(bytes)
+def _(value: bytes, _width: int) -> bytes:
+    """Truncate a given binary bytes into a given width."""
+    return value[:_width]
+
+
+@_truncate_value.register(Decimal)
+def _(value: Decimal, _width: int) -> Decimal:
+    """Truncate a given decimal value into a given width."""
+    return truncate_decimal(value, _width)
 
 
 class UnknownTransform(Transform):
@@ -369,5 +470,9 @@ def identity(source_type: IcebergType) -> IdentityTransform:
     return IdentityTransform(source_type)
 
 
+def truncate(source_type: IcebergType, width: int) -> TruncateTransform:
+    return TruncateTransform(source_type, width)
+
+
 def always_null() -> Transform:
     return VoidTransform()
diff --git a/python/src/iceberg/utils/decimal.py b/python/src/iceberg/utils/decimal.py
index 1d4c2bddefd0..40bc087390c3 100644
--- a/python/src/iceberg/utils/decimal.py
+++ b/python/src/iceberg/utils/decimal.py
@@ -75,3 +75,17 @@ def decimal_to_bytes(value: Decimal) -> bytes:
     """
     unscaled_value = decimal_to_unscaled(value)
     return unscaled_value.to_bytes(bytes_required(unscaled_value), byteorder="big", signed=True)
+
+
+def truncate_decimal(value: Decimal, width: int) -> Decimal:
+    """Get a truncated Decimal value given a decimal value and a width
+    Args:
+        value (Decimal): a decimal value
+        width (int): A positive truncation width, applied to the unscaled value
+    Returns:
+        Decimal: A truncated Decimal instance
+    """
+    unscaled_value = decimal_to_unscaled(value)
+    # Python's % takes the sign of the divisor, so with a positive width the remainder is already non-negative
+    applied_value = unscaled_value - unscaled_value % width
+    return unscaled_to_decimal(applied_value, -value.as_tuple().exponent)
diff --git a/python/tests/test_transforms.py b/python/tests/test_transforms.py
index d7b8e968a9fd..dc3ce4ec2737 100644
--- a/python/tests/test_transforms.py
+++ b/python/tests/test_transforms.py
@@ -177,6 +177,69 @@ def test_identity_method(type_var):
     assert identity_transform.apply("test") == "test"
 
 
+@pytest.mark.parametrize("type_var", [IntegerType(), LongType()])
+@pytest.mark.parametrize(
+    "input_var,expected",
+    [(1, 0), (5, 0), (9, 0), (10, 10), (11, 10), (-1, -10), (-10, -10), (-12, -20)],
+)
+def test_truncate_integer(type_var, input_var, expected):
+    trunc = transforms.truncate(type_var, 10)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize(
+    "input_var,expected",
+    [
+        (Decimal("12.34"), Decimal("12.30")),
+        (Decimal("12.30"), Decimal("12.30")),
+        (Decimal("12.29"), Decimal("12.20")),
+        (Decimal("0.05"), Decimal("0.00")),
+        (Decimal("-0.05"), Decimal("-0.10")),
+    ],
+)
+def test_truncate_decimal(input_var, expected):
+    trunc = transforms.truncate(DecimalType(9, 2), 10)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize("input_var,expected", [("abcdefg", "abcde"), ("abc", "abc")])
+def test_truncate_string(input_var, expected):
+    trunc = transforms.truncate(StringType(), 5)
+    assert trunc.apply(input_var) == expected
+
+
+@pytest.mark.parametrize(
+    "type_var,value,expected_human_str,expected",
+    [
+        (BinaryType(), b"\x00\x01\x02\x03", "AAECAw==", b"\x00"),
+        (BinaryType(), bytes("\u2603de", "utf-8"), "4piDZGU=", b"\xe2"),
+        (DecimalType(8, 5), Decimal("14.21"), "14.21", Decimal("14.21")),
+        (IntegerType(), 123, "123", 123),
+        (LongType(), 123, "123", 123),
+        (StringType(), "foo", "foo", "f"),
+        (StringType(), "\u2603de", "\u2603de", "\u2603"),
+    ],
+)
+def test_truncate_method(type_var, value, expected_human_str, expected):
+    truncate_transform = transforms.truncate(type_var, 1)
+    assert str(truncate_transform) == str(eval(repr(truncate_transform)))
+    assert truncate_transform.can_transform(type_var)
+    assert truncate_transform.result_type(type_var) == type_var
+    assert truncate_transform.to_human_string(value) == expected_human_str
+    assert truncate_transform.apply(value) == expected
+    assert truncate_transform.to_human_string(None) == "null"
+    assert truncate_transform.width == 1
+    assert truncate_transform.apply(None) is None
+    assert truncate_transform.preserves_order
+    assert truncate_transform.satisfies_order_of(truncate_transform)
+
+
+@pytest.mark.parametrize("width", [0, -1])
+def test_truncate_invalid_width(width):
+    with pytest.raises(ValueError):
+        transforms.truncate(IntegerType(), width)
+
+
 def test_unknown_transform():
     unknown_transform = transforms.UnknownTransform(FixedType(8), "unknown")
     assert str(unknown_transform) == str(eval(repr(unknown_transform)))