diff --git a/python/iceberg/api/partition_spec.py b/python/iceberg/api/partition_spec.py index 95cce0e6a679..bcea93f2af8b 100644 --- a/python/iceberg/api/partition_spec.py +++ b/python/iceberg/api/partition_spec.py @@ -236,7 +236,7 @@ def year(self, source_name): source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, name, - Transforms.year(source_column.types))) + Transforms.year(source_column.type))) return self def month(self, source_name): @@ -245,7 +245,7 @@ def month(self, source_name): source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, name, - Transforms.month(source_column.types))) + Transforms.month(source_column.type))) return self def day(self, source_name): @@ -254,7 +254,7 @@ def day(self, source_name): source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, name, - Transforms.day(source_column.types))) + Transforms.day(source_column.type))) return self def hour(self, source_name): @@ -281,7 +281,7 @@ def truncate(self, source_name, width): source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, name, - Transforms.truncate(source_column.types, width))) + Transforms.truncate(source_column.type, width))) return self def add(self, source_id, name, transform): diff --git a/python/iceberg/api/transforms/bucket.py b/python/iceberg/api/transforms/bucket.py index c1672da4a4eb..64db7122089e 100644 --- a/python/iceberg/api/transforms/bucket.py +++ b/python/iceberg/api/transforms/bucket.py @@ -101,7 +101,7 @@ def hash(self, value): return Bucket.MURMUR3.hash(struct.pack("q", value)) def can_transform(self, type_var): - return type_var.type_id() in [TypeID.INTEGER, TypeID.DATE] + return type_var.type_id in [TypeID.INTEGER, TypeID.DATE] class BucketLong(Bucket): @@ -112,9 +112,9 @@ def hash(self, value): return 
Bucket.MURMUR3.hash(struct.pack("q", value)) def can_transform(self, type_var): - return type_var.type_id() in [TypeID.LONG, - TypeID.TIME, - TypeID.TIMESTAMP] + return type_var.type_id in [TypeID.LONG, + TypeID.TIME, + TypeID.TIMESTAMP] class BucketFloat(Bucket): @@ -125,7 +125,7 @@ def hash(self, value): return Bucket.MURMUR3.hash(struct.pack("d", value)) def can_transform(self, type_var): - return type_var.type_id() == TypeID.FLOAT + return type_var.type_id == TypeID.FLOAT class BucketDouble(Bucket): @@ -136,7 +136,7 @@ def hash(self, value): return Bucket.MURMUR3.hash(struct.pack("d", value)) def can_transform(self, type_var): - return type_var.type_id() == TypeID.DOUBLE + return type_var.type_id == TypeID.DOUBLE class BucketDecimal(Bucket): diff --git a/python/iceberg/api/transforms/dates.py b/python/iceberg/api/transforms/dates.py index 628281b20b81..474b986f696d 100644 --- a/python/iceberg/api/transforms/dates.py +++ b/python/iceberg/api/transforms/dates.py @@ -52,7 +52,7 @@ def apply(self, days): return apply_func(datetime.datetime.utcfromtimestamp(days * Dates.SECONDS_IN_DAY), Dates.EPOCH) def can_transform(self, type): - return type.type_id() == TypeID.DATE + return type.type_id == TypeID.DATE def get_result_type(self, source_type): return IntegerType.get() @@ -73,4 +73,4 @@ def to_human_string(self, value): return Dates.HUMAN_FUNCS[self.granularity](value) def __str__(self): - return "%s" % self + return self.name diff --git a/python/iceberg/api/transforms/timestamps.py b/python/iceberg/api/transforms/timestamps.py index 697cec67e594..25c4439bc179 100644 --- a/python/iceberg/api/transforms/timestamps.py +++ b/python/iceberg/api/transforms/timestamps.py @@ -50,7 +50,7 @@ def apply(self, value): return apply_func(datetime.datetime.utcfromtimestamp(value / 1000000), Timestamps.EPOCH) def can_transform(self, type_var): - return type_var == TypeID.TIMESTAMP + return type_var.type_id == TypeID.TIMESTAMP def get_result_type(self, source_type): return 
IntegerType.get() diff --git a/python/iceberg/api/transforms/transforms.py b/python/iceberg/api/transforms/transforms.py index c14d84930f6b..33877f1fd806 100644 --- a/python/iceberg/api/transforms/transforms.py +++ b/python/iceberg/api/transforms/transforms.py @@ -42,22 +42,22 @@ def __init__(self): pass @staticmethod - def from_string(type, transform): + def from_string(type_var, transform): match = Transforms.HAS_WIDTH.match(transform) if match is not None: name = match.group(1) w = match.group(2) if name.lower() == "truncate": - return Truncate.get(type, w) + return Truncate.get(type_var, w) elif name.lower() == "bucket": - return Bucket.get(type, w) + return Bucket.get(type_var, w) if transform.lower() == "identity": - return Identity.get(type) - elif type.type_id() == TypeID.TIMESTAMP: + return Identity.get(type_var) + elif type_var.type_id == TypeID.TIMESTAMP: return Timestamps(transform.lower(), transform.lower()) - elif type.type_id() == TypeID.DATE: + elif type_var.type_id == TypeID.DATE: return Dates(transform.lower(), transform.lower()) raise RuntimeError("Unknown transform: %s" % transform) @@ -108,4 +108,4 @@ def bucket(type_var, num_buckets): @staticmethod def truncate(type_var, width): - return Truncate.get(type, width) + return Truncate.get(type_var, width) diff --git a/python/iceberg/api/transforms/truncate.py b/python/iceberg/api/transforms/truncate.py index cd001ad44eac..b37e6f0b06dd 100644 --- a/python/iceberg/api/transforms/truncate.py +++ b/python/iceberg/api/transforms/truncate.py @@ -32,7 +32,7 @@ def get(type_var, width): if type_var.type_id == TypeID.INTEGER: return TruncateInteger(width) elif type_var.type_id == TypeID.LONG: - return TruncateInteger(width) + return TruncateLong(width) elif type_var.type_id == TypeID.DECIMAL: return TruncateDecimal(width) elif type_var.type_id == TypeID.STRING: diff --git a/python/tests/core/test_partition_spec.py b/python/tests/core/test_partition_spec.py new file mode 100644 index 
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from iceberg.api import PartitionSpec, Schema
from iceberg.api.types import (BinaryType,
                               DateType,
                               DecimalType,
                               FixedType,
                               IntegerType,
                               LongType,
                               NestedField,
                               StringType,
                               TimestampType,
                               TimeType,
                               UUIDType)


def test_to_json_conversion():
    """Check the string form of single-field partition specs.

    A ten-column schema (one column per primitive type exercised here) is
    used to build one ``PartitionSpec`` per transform/column combination:
    identity, bucket, year/month/day/hour, truncate, and a transform added
    from its string form via ``add``.  Despite the test's name, the
    assertions compare ``str(spec)`` against the expected text.

    Each spec is stored next to its expected string in a single list of
    pairs.  Two parallel lists joined with ``zip`` would silently stop at
    the shorter list, so a missing expectation could leave specs unchecked
    while the test still passed.
    """
    spec_schema = Schema(NestedField.required(1, "i", IntegerType.get()),
                         NestedField.required(2, "l", LongType.get()),
                         NestedField.required(3, "d", DateType.get()),
                         NestedField.required(4, "t", TimeType.get()),
                         NestedField.required(5, "ts", TimestampType.without_timezone()),
                         NestedField.required(6, "dec", DecimalType.of(9, 2)),
                         NestedField.required(7, "s", StringType.get()),
                         NestedField.required(8, "u", UUIDType.get()),
                         NestedField.required(9, "f", FixedType.of_length(3)),
                         NestedField.required(10, "b", BinaryType.get()))

    def builder():
        # A fresh builder per case so that every spec holds exactly one field.
        return PartitionSpec.builder_for(spec_schema)

    # (spec, expected str(spec)) pairs; the number in parentheses in each
    # expected string is the id of the source column in spec_schema.
    cases = [
        # identity
        (builder().identity("i").build(), "[\n i: identity(1)\n]"),
        (builder().identity("l").build(), "[\n l: identity(2)\n]"),
        (builder().identity("d").build(), "[\n d: identity(3)\n]"),
        (builder().identity("t").build(), "[\n t: identity(4)\n]"),
        (builder().identity("ts").build(), "[\n ts: identity(5)\n]"),
        (builder().identity("dec").build(), "[\n dec: identity(6)\n]"),
        (builder().identity("s").build(), "[\n s: identity(7)\n]"),
        (builder().identity("u").build(), "[\n u: identity(8)\n]"),
        (builder().identity("f").build(), "[\n f: identity(9)\n]"),
        (builder().identity("b").build(), "[\n b: identity(10)\n]"),
        # bucket
        (builder().bucket("i", 128).build(), "[\n i_bucket: bucket[128](1)\n]"),
        (builder().bucket("l", 128).build(), "[\n l_bucket: bucket[128](2)\n]"),
        (builder().bucket("d", 128).build(), "[\n d_bucket: bucket[128](3)\n]"),
        (builder().bucket("t", 128).build(), "[\n t_bucket: bucket[128](4)\n]"),
        (builder().bucket("ts", 128).build(), "[\n ts_bucket: bucket[128](5)\n]"),
        (builder().bucket("dec", 128).build(), "[\n dec_bucket: bucket[128](6)\n]"),
        (builder().bucket("s", 128).build(), "[\n s_bucket: bucket[128](7)\n]"),
        # date transforms
        (builder().year("d").build(), "[\n d_year: year(3)\n]"),
        (builder().month("d").build(), "[\n d_month: month(3)\n]"),
        (builder().day("d").build(), "[\n d_day: day(3)\n]"),
        # timestamp transforms
        (builder().year("ts").build(), "[\n ts_year: year(5)\n]"),
        (builder().month("ts").build(), "[\n ts_month: month(5)\n]"),
        (builder().day("ts").build(), "[\n ts_day: day(5)\n]"),
        (builder().hour("ts").build(), "[\n ts_hour: hour(5)\n]"),
        # truncate
        (builder().truncate("i", 10).build(), "[\n i_truncate: truncate[10](1)\n]"),
        (builder().truncate("l", 10).build(), "[\n l_truncate: truncate[10](2)\n]"),
        (builder().truncate("dec", 10).build(), "[\n dec_truncate: truncate[10](6)\n]"),
        (builder().truncate("s", 10).build(), "[\n s_truncate: truncate[10](7)\n]"),
        # transform supplied by source id, field name and transform string
        (builder().add(6, "dec_bucket", "bucket[16]").build(), "[\n dec_bucket: bucket[16](6)\n]"),
    ]

    for spec, expected_spec_str in cases:
        assert str(spec) == expected_spec_str