Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ package_dir =
packages = find:
python_requires = >=3.7
install_requires =
mmh3
singledispatch
[options.extras_require]
arrow =
Expand Down
240 changes: 240 additions & 0 deletions python/src/iceberg/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import math
import struct
from decimal import Decimal
from typing import Optional
from uuid import UUID

import mmh3 # type: ignore

from iceberg.types import (
BinaryType,
DateType,
DecimalType,
FixedType,
IcebergType,
IntegerType,
LongType,
StringType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)


class Transform:
"""Transform base class for concrete transforms.

A base class to transform values and project predicates on partition values.
This class is not used directly. Instead, use one of module method to create the child classes.

Args:
transform_string (str): name of the transform type
repr_string (str): string representation of a transform instance
"""

def __init__(self, transform_string: str, repr_string: str):
self._transform_string = transform_string
self._repr_string = repr_string

def __repr__(self):
return self._repr_string

def __str__(self):
return self._transform_string

def __call__(self, value):
return self.apply(value)

def apply(self, value):
raise NotImplementedError()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@samredai, do you think we should convert this to an ABC?


def can_transform(self, source: IcebergType) -> bool:
return False

def result_type(self, source: IcebergType) -> IcebergType:
raise NotImplementedError()

def preserves_order(self) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want this to be @property?

return False

def satisfies_order_of(self, other) -> bool:
return self == other

def to_human_string(self, value) -> str:
if value is None:
return "null"
return str(value)

def dedup_name(self) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want to make this a @property.

return self._transform_string


class BaseBucketTransform(Transform):
"""Base Transform class to transform a value into a bucket partition value

Transforms are parameterized by a number of buckets. Bucket partition transforms use a 32-bit
hash of the source value to produce a positive value by mod the bucket number.

Args:
source_type (Type): An Iceberg Type of IntegerType, LongType, DecimalType, DateType, TimeType,
TimestampType, TimestamptzType, StringType, BinaryType, FixedType, UUIDType.
num_buckets (int): The number of buckets.
"""

def __init__(self, source_type: IcebergType, num_buckets: int):
super().__init__(
f"bucket[{num_buckets}]",
f"transforms.bucket(source_type={repr(source_type)}, num_buckets={num_buckets})",
)
self._num_buckets = num_buckets

@property
def num_buckets(self) -> int:
return self._num_buckets

def hash(self, value) -> int:
raise NotImplementedError()

def apply(self, value) -> Optional[int]:
if value is None:
return None

return (self.hash(value) & IntegerType.max) % self._num_buckets

def can_transform(self, source: IcebergType) -> bool:
raise NotImplementedError()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should fall back to the parent's implementation rather than throwing NotImplementedError.


def result_type(self, source: IcebergType) -> IcebergType:
return IntegerType()


class BucketNumberTransform(BaseBucketTransform):
"""Transforms a value of IntegerType, LongType, DateType, TimeType, TimestampType, or TimestamptzType
into a bucket partition value

Example:
>>> transform = BucketNumberTransform(LongType(), 100)
>>> transform.apply(81068000000)
59
"""

def can_transform(self, source: IcebergType) -> bool:
return type(source) in {IntegerType, DateType, LongType, TimeType, TimestampType, TimestamptzType}

def hash(self, value) -> int:
return mmh3.hash(struct.pack("<q", value))


class BucketDecimalTransform(BaseBucketTransform):
"""Transforms a value of DecimalType into a bucket partition value.

Example:
>>> transform = BucketDecimalTransform(DecimalType(9, 2), 100)
>>> transform.apply(Decimal("14.20"))
59
"""

def can_transform(self, source: IcebergType) -> bool:
return isinstance(source, DecimalType)

def hash(self, value: Decimal) -> int:
value_tuple = value.as_tuple()
unscaled_value = int(("-" if value_tuple.sign else "") + "".join([str(d) for d in value_tuple.digits]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should call iceberg.conversions.decimal_to_unscaled to get the unscaled value.

And min_num_bytes should be ((unscaled_value).bit_length() + 7) // 8. We can make that a utility function as well.

number_of_bytes = int(math.ceil(unscaled_value.bit_length() / 8))
value_in_bytes = unscaled_value.to_bytes(length=number_of_bytes, byteorder="big")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also use signed=True like conversions:

unscaled_value.to_bytes(min_num_bytes, "big", signed=True)

return mmh3.hash(value_in_bytes)


class BucketStringTransform(BaseBucketTransform):
"""Transforms a value of StringType into a bucket partition value.

Example:
>>> transform = BucketStringTransform(100)
>>> transform.apply("iceberg")
89
"""

def __init__(self, num_buckets: int):
super().__init__(StringType(), num_buckets)

def can_transform(self, source: IcebergType) -> bool:
return isinstance(source, StringType)

def hash(self, value: str) -> int:
return mmh3.hash(value)


class BucketBytesTransform(BaseBucketTransform):
"""Transforms a value of FixedType or BinaryType into a bucket partition value.

Example:
>>> transform = BucketBytesTransform(BinaryType(), 100)
>>> transform.apply(b"\\x00\\x01\\x02\\x03")
41
"""

def can_transform(self, source: IcebergType) -> bool:
return type(source) in {FixedType, BinaryType}

def hash(self, value: bytes) -> int:
return mmh3.hash(value)


class BucketUUIDTransform(BaseBucketTransform):
"""Transforms a value of UUIDType into a bucket partition value.

Example:
>>> transform = BucketUUIDTransform(100)
>>> transform.apply(UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"))
40
"""

def __init__(self, num_buckets: int):
super().__init__(UUIDType(), num_buckets)

def can_transform(self, source: IcebergType) -> bool:
return isinstance(source, UUIDType)

def hash(self, value: UUID) -> int:
return mmh3.hash(
struct.pack(
">QQ",
(value.int >> 64) & 0xFFFFFFFFFFFFFFFF,
value.int & 0xFFFFFFFFFFFFFFFF,
)
)


def bucket(source_type: IcebergType, num_buckets: int) -> BaseBucketTransform:
if type(source_type) in {IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType}:
return BucketNumberTransform(source_type, num_buckets)
elif isinstance(source_type, DecimalType):
return BucketDecimalTransform(source_type, num_buckets)
elif isinstance(source_type, StringType):
return BucketStringTransform(num_buckets)
elif isinstance(source_type, BinaryType):
return BucketBytesTransform(source_type, num_buckets)
elif isinstance(source_type, FixedType):
return BucketBytesTransform(source_type, num_buckets)
elif isinstance(source_type, UUIDType):
return BucketUUIDTransform(num_buckets)
else:
raise ValueError(f"Cannot bucket by type: {source_type}")
133 changes: 133 additions & 0 deletions python/tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import datetime
from decimal import Decimal, getcontext
from uuid import UUID

import pytest

from iceberg import transforms
from iceberg.types import (
BinaryType,
DateType,
DecimalType,
FixedType,
IntegerType,
LongType,
StringType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)


@pytest.mark.parametrize(
"test_input,test_type,expected",
[
(1, IntegerType(), 1392991556),
(34, IntegerType(), 2017239379),
(34, LongType(), 2017239379),
(17486, DateType(), -653330422),
(81068000000, TimeType(), -662762989),
(
int(datetime.fromisoformat("2017-11-16T22:31:08+00:00").timestamp() * 1000000),
TimestampType(),
-2047944441,
),
(
int(datetime.fromisoformat("2017-11-16T14:31:08-08:00").timestamp() * 1000000),
TimestamptzType(),
-2047944441,
),
(b"\x00\x01\x02\x03", BinaryType(), -188683207),
(b"\x00\x01\x02\x03", FixedType(4), -188683207),
("iceberg", StringType(), 1210000089),
(UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"), UUIDType(), 1488055340),
],
)
def test_bucket_hash_values(test_input, test_type, expected):
assert transforms.bucket(test_type, 8).hash(test_input) == expected


@pytest.mark.parametrize(
"test_input,test_type,scale_factor,expected_hash,expected",
[
(Decimal("14.20"), DecimalType(9, 2), Decimal(10) ** -2, -500754589, 59),
(
Decimal("137302769811943318102518958871258.37580"),
DecimalType(38, 5),
Decimal(10) ** -5,
-32334285,
63,
),
],
)
def test_decimal_bucket(test_input, test_type, scale_factor, expected_hash, expected):
getcontext().prec = 38
assert transforms.bucket(test_type, 100).hash(test_input.quantize(scale_factor)) == expected_hash
assert transforms.bucket(test_type, 100).apply(test_input.quantize(scale_factor)) == expected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is needed. No decimal arithmetic is necessary because we can convert directly to the unscaled value using the conversion that @samredai added. There should be no need to change precision and quantize anything, so we don't need tests for it.



@pytest.mark.parametrize(
"bucket,value,expected",
[
(transforms.bucket(IntegerType(), 100), 34, 79),
(transforms.bucket(LongType(), 100), 34, 79),
(transforms.bucket(DateType(), 100), 17486, 26),
(transforms.bucket(TimeType(), 100), 81068000000, 59),
(transforms.bucket(TimestampType(), 100), 1510871468000000, 7),
(transforms.bucket(DecimalType(9, 2), 100), Decimal("14.20"), 59),
(transforms.bucket(StringType(), 100), "iceberg", 89),
(
transforms.bucket(UUIDType(), 100),
UUID("f79c3e09-677c-4bbd-a479-3f349cb785e7"),
40,
),
(transforms.bucket(FixedType(3), 128), b"foo", 32),
(transforms.bucket(BinaryType(), 128), b"\x00\x01\x02\x03", 57),
],
)
def test_buckets(bucket, value, expected):
assert bucket.apply(value) == expected


@pytest.mark.parametrize(
"type_var",
[
BinaryType(),
DateType(),
DecimalType(8, 5),
FixedType(8),
IntegerType(),
LongType(),
StringType(),
TimestampType(),
TimestamptzType(),
TimeType(),
UUIDType(),
],
)
def test_bucket_method(type_var):
bucket_transform = transforms.bucket(type_var, 8)
assert str(bucket_transform) == str(eval(repr(bucket_transform)))
assert bucket_transform.can_transform(type_var)
assert bucket_transform.result_type(type_var) == IntegerType()
assert bucket_transform.num_buckets == 8
assert bucket_transform.apply(None) is None
assert bucket_transform.to_human_string("test") == "test"