diff --git a/python/iceberg/api/partition_spec.py b/python/iceberg/api/partition_spec.py index 395001b60162..e2f106d76b3b 100644 --- a/python/iceberg/api/partition_spec.py +++ b/python/iceberg/api/partition_spec.py @@ -21,7 +21,7 @@ from .partition_field import PartitionField from .schema import Schema -from .transforms import Transforms +from .transforms import Transform, Transforms from .types import (NestedField, StructType) @@ -202,6 +202,7 @@ def __init__(self, schema): self.schema = schema self.fields = list() self.partition_names = set() + self.dedup_fields = dict() self.spec_id = 0 self.last_assigned_field_id = PartitionSpec.PARTITION_DATA_ID_START - 1 @@ -213,15 +214,35 @@ def with_spec_id(self, spec_id): self.spec_id = spec_id return self - def check_and_add_partition_name(self, name): + def check_and_add_partition_name(self, name, source_column_id=None): + schema_field = self.schema.find_field(name) + if source_column_id is not None: + if schema_field is not None and schema_field.field_id != source_column_id: + raise ValueError("Cannot create identity partition sourced from different field in schema: %s" % name) + else: + if schema_field is not None: + raise ValueError("Cannot create partition from name that exists in schema: %s" % name) + if name is None or name == "": - raise RuntimeError("Cannot use empty or null partition name") + raise ValueError("Cannot use empty or null partition name: %s" % name) if name in self.partition_names: - raise RuntimeError("Cannot use partition names more than once: %s" % name) + raise ValueError("Cannot use partition names more than once: %s" % name) self.partition_names.add(name) return self + def check_redundant_and_add_field(self, field_id: int, name: str, transform: Transform) -> None: + field = PartitionField(field_id, + self.__next_field_id(), + name, + transform) + dedup_key = (field.source_id, field.transform.dedup_name()) + partition_field = self.dedup_fields.get(dedup_key) + if partition_field is not None: + raise ValueError("Cannot add redundant partition: %s conflicts with %s" % (partition_field, field)) + self.dedup_fields[dedup_key] = field + self.fields.append(field) + def find_source_column(self, source_name): source_column = self.schema.find_field(source_name) if source_column is None: @@ -229,72 +250,82 @@ def find_source_column(self, source_name): return source_column - def identity(self, source_name): - self.check_and_add_partition_name(source_name) + def identity(self, source_name, target_name=None): + if target_name is None: + target_name = source_name + source_column = self.find_source_column(source_name) - self.fields.append(PartitionField(source_column.field_id, - self.__next_field_id(), - source_name, - Transforms.identity(source_column.type))) + self.check_and_add_partition_name(target_name, source_column.field_id) + self.check_redundant_and_add_field(source_column.field_id, + target_name, + Transforms.identity(source_column.type)) return self - def year(self, source_name): - name = "{}_year".format(source_name) - self.check_and_add_partition_name(name) + def year(self, source_name, target_name=None): + if target_name is None: + target_name = "{}_year".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) - self.fields.append(PartitionField(source_column.field_id, - self.__next_field_id(), - name, - Transforms.year(source_column.type))) + self.check_redundant_and_add_field(source_column.field_id, + target_name, + Transforms.year(source_column.type)) return self - def month(self, source_name): - name = "{}_month".format(source_name) - self.check_and_add_partition_name(name) + def month(self, source_name, target_name=None): + if target_name is None: + target_name = "{}_month".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) - self.fields.append(PartitionField(source_column.field_id, - self.__next_field_id(), - name, - Transforms.month(source_column.type))) + self.check_redundant_and_add_field(source_column.field_id, + target_name, + Transforms.month(source_column.type)) return self - def day(self, source_name): - name = "{}_day".format(source_name) - self.check_and_add_partition_name(name) + def day(self, source_name, target_name=None): + if target_name is None: + target_name = "{}_day".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) - self.fields.append(PartitionField(source_column.field_id, - self.__next_field_id(), - name, - Transforms.day(source_column.type))) + self.check_redundant_and_add_field(source_column.field_id, + target_name, + Transforms.day(source_column.type)) return self - def hour(self, source_name): - name = "{}_hour".format(source_name) - self.check_and_add_partition_name(name) + def hour(self, source_name, target_name=None): + if target_name is None: + target_name = "{}_hour".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) - self.fields.append(PartitionField(source_column.field_id, - self.__next_field_id(), - name, - Transforms.hour(source_column.type))) + self.check_redundant_and_add_field(source_column.field_id, + target_name, + Transforms.hour(source_column.type)) return self - def bucket(self, source_name, num_buckets): - name = "{}_bucket".format(source_name) - self.check_and_add_partition_name(name) + def bucket(self, source_name, num_buckets, target_name=None): + if target_name is None: + target_name = "{}_bucket".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, self.__next_field_id(), - name, + target_name, Transforms.bucket(source_column.type, num_buckets))) return self - def truncate(self, source_name, width): - name = "{}_truncate".format(source_name) - self.check_and_add_partition_name(name) + def truncate(self, source_name, width, target_name=None): + if target_name is None: + target_name = "{}_truncate".format(source_name) + + self.check_and_add_partition_name(target_name) source_column = self.find_source_column(source_name) self.fields.append(PartitionField(source_column.field_id, self.__next_field_id(), - name, + target_name, Transforms.truncate(source_column.type, width))) return self @@ -302,17 +333,16 @@ def add_without_field_id(self, source_id, name, transform): return self.add(source_id, self.__next_field_id(), name, transform) def add(self, source_id: int, field_id: int, name: str, transform: str) -> "PartitionSpecBuilder": - self.check_and_add_partition_name(name) column = self.schema.find_field(source_id) if column is None: - raise RuntimeError("Cannot find source column: %s" % source_id) + raise ValueError("Cannot find source column: %s" % source_id) transform_obj = Transforms.from_string(column.type, transform) - field = PartitionField(source_id, - field_id, - name, - transform_obj) - self.fields.append(field) + self.check_and_add_partition_name(name, source_id) + self.fields.append(PartitionField(source_id, + field_id, + name, + transform_obj)) self.last_assigned_field_id = max(self.last_assigned_field_id, field_id) return self diff --git a/python/iceberg/api/transforms/dates.py b/python/iceberg/api/transforms/dates.py index 474b986f696d..dfc6b9a2cd02 100644 --- a/python/iceberg/api/transforms/dates.py +++ b/python/iceberg/api/transforms/dates.py @@ -74,3 +74,14 @@ def to_human_string(self, value): def __str__(self): return self.name + + def dedup_name(self): + return "time" + + def __eq__(self, other): + if id(self) == id(other): + return True + if other is None or not isinstance(other, Dates): + return False + + return self.granularity == other.granularity and self.name == other.name diff --git a/python/iceberg/api/transforms/timestamps.py b/python/iceberg/api/transforms/timestamps.py index 25c4439bc179..ca38a1c3bedc 100644 --- a/python/iceberg/api/transforms/timestamps.py +++ b/python/iceberg/api/transforms/timestamps.py @@ -70,3 +70,14 @@ def to_human_string(self, value): def __str__(self): return self.name + + def dedup_name(self): + return "time" + + def __eq__(self, other): + if id(self) == id(other): + return True + if other is None or not isinstance(other, Timestamps): + return False + + return self.granularity == other.granularity and self.name == other.name diff --git a/python/iceberg/api/transforms/transform.py b/python/iceberg/api/transforms/transform.py index 776b0a507b3a..0ccbf2e3c25e 100644 --- a/python/iceberg/api/transforms/transform.py +++ b/python/iceberg/api/transforms/transform.py @@ -38,3 +38,6 @@ def project_strict(self, name, predicate): def to_human_string(self, value): return str(value) + + def dedup_name(self): + return self.__str__() diff --git a/python/tests/api/test_partition_spec.py b/python/tests/api/test_partition_spec.py new file mode 100644 index 000000000000..93cee149600c --- /dev/null +++ b/python/tests/api/test_partition_spec.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from iceberg.api import PartitionSpec +from iceberg.api.schema import Schema +from iceberg.api.types import (BinaryType, + DateType, + DecimalType, + FixedType, + IntegerType, + LongType, + NestedField, + StringType, + TimestampType, + TimeType, + UUIDType) +from tests.api.test_helpers import TestHelpers + + +class TestConversions(unittest.TestCase): + + def test_transforms(self): + schema = Schema(NestedField.required(1, "i", IntegerType.get()), + NestedField.required(2, "l", LongType.get()), + NestedField.required(3, "d", DateType.get()), + NestedField.required(4, "t", TimeType.get()), + NestedField.required(5, "ts", TimestampType.without_timezone()), + NestedField.required(6, "dec", DecimalType.of(9, 2)), + NestedField.required(7, "s", StringType.get()), + NestedField.required(8, "u", UUIDType.get()), + NestedField.required(9, "f", FixedType.of_length(3)), + NestedField.required(10, "b", BinaryType.get())) + specs = [PartitionSpec.builder_for(schema).identity("i").build(), + PartitionSpec.builder_for(schema).identity("l").build(), + PartitionSpec.builder_for(schema).identity("d").build(), + PartitionSpec.builder_for(schema).identity("t").build(), + PartitionSpec.builder_for(schema).identity("ts").build(), + PartitionSpec.builder_for(schema).identity("dec").build(), + PartitionSpec.builder_for(schema).identity("s").build(), + PartitionSpec.builder_for(schema).identity("u").build(), + PartitionSpec.builder_for(schema).identity("f").build(), + PartitionSpec.builder_for(schema).identity("b").build(), + PartitionSpec.builder_for(schema).bucket("i", 128).build(), + PartitionSpec.builder_for(schema).bucket("l", 128).build(), + PartitionSpec.builder_for(schema).bucket("d", 128).build(), + PartitionSpec.builder_for(schema).bucket("t", 128).build(), + PartitionSpec.builder_for(schema).bucket("ts", 128).build(), + PartitionSpec.builder_for(schema).bucket("dec", 128).build(), + PartitionSpec.builder_for(schema).bucket("s", 128).build(), + # todo support them + # PartitionSpec.builder_for(schema).bucket("u", 128).build(), + # PartitionSpec.builder_for(schema).bucket("f", 128).build(), + # PartitionSpec.builder_for(schema).bucket("b", 128).build(), + PartitionSpec.builder_for(schema).year("d").build(), + PartitionSpec.builder_for(schema).month("d").build(), + PartitionSpec.builder_for(schema).day("d").build(), + PartitionSpec.builder_for(schema).year("ts").build(), + PartitionSpec.builder_for(schema).month("ts").build(), + PartitionSpec.builder_for(schema).day("ts").build(), + PartitionSpec.builder_for(schema).hour("ts").build(), + PartitionSpec.builder_for(schema).truncate("i", 10).build(), + PartitionSpec.builder_for(schema).truncate("l", 10).build(), + PartitionSpec.builder_for(schema).truncate("dec", 10).build(), + PartitionSpec.builder_for(schema).truncate("s", 10).build(), + # todo support them + # PartitionSpec.builder_for(schema).add_without_field_id(6, "dec_unsupported", "unsupported").build(), + # PartitionSpec.builder_for(schema).add(6, 1111, "dec_unsupported", "unsupported").build(), + ] + + for spec in specs: + self.assertEqual(spec, TestHelpers.round_trip_serialize(spec))