From eef72f3db59b6b1ac82e173997c81dbf6e339a73 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Thu, 28 Mar 2024 21:05:38 +0000 Subject: [PATCH 01/15] partitioned append on identity transform --- pyiceberg/io/pyarrow.py | 86 ++-- pyiceberg/manifest.py | 6 - pyiceberg/table/__init__.py | 180 ++++++- tests/catalog/test_sql.py | 3 - tests/integration/test_partitioned_write.py | 534 ++++++++++++++++++++ tests/integration/test_reads.py | 1 - tests/integration/test_writes.py | 1 - 7 files changed, 740 insertions(+), 71 deletions(-) create mode 100644 tests/integration/test_partitioned_write.py diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index baefb3564f..6d7194be73 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1761,54 +1761,46 @@ def data_file_statistics_from_parquet_metadata( def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]: - task = next(tasks) - - try: - _ = next(tasks) - # If there are more tasks, raise an exception - raise NotImplementedError("Only unpartitioned writes are supported: https://github.com/apache/iceberg-python/issues/208") - except StopIteration: - pass - - parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) - - file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}' - schema = table_metadata.schema() - arrow_file_schema = schema.as_arrow() - - fo = io.new_output(file_path) - row_group_size = PropertyUtil.property_as_int( - properties=table_metadata.properties, - property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, - default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT, - ) - with fo.create(overwrite=True) as fos: - with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer: - writer.write_table(task.df, row_group_size=row_group_size) - - statistics = data_file_statistics_from_parquet_metadata( - parquet_metadata=writer.writer.metadata, - stats_columns=compute_statistics_plan(schema, table_metadata.properties), - parquet_column_mapping=parquet_path_to_id_mapping(schema), - ) - data_file = DataFile( - content=DataFileContent.DATA, - file_path=file_path, - file_format=FileFormat.PARQUET, - partition=Record(), - file_size_in_bytes=len(fo), - # After this has been fixed: - # https://github.com/apache/iceberg-python/issues/271 - # sort_order_id=task.sort_order_id, - sort_order_id=None, - # Just copy these from the table for now - spec_id=table_metadata.default_spec_id, - equality_ids=None, - key_metadata=None, - **statistics.to_serialized_dict(), - ) + for task in tasks: + parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties) + + file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' # generate_data_file_filename + schema = table_metadata.schema() + arrow_file_schema = schema.as_arrow() + + fo = io.new_output(file_path) + row_group_size = PropertyUtil.property_as_int( + properties=table_metadata.properties, + property_name=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES, + default=TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT, + ) + with fo.create(overwrite=True) as fos: + with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer: + writer.write_table(task.df, row_group_size=row_group_size) - return iter([data_file]) + statistics = data_file_statistics_from_parquet_metadata( + parquet_metadata=writer.writer.metadata, + 
stats_columns=compute_statistics_plan(schema, table_metadata.properties), + parquet_column_mapping=parquet_path_to_id_mapping(schema), + ) + data_file = DataFile( + content=DataFileContent.DATA, + file_path=file_path, + file_format=FileFormat.PARQUET, + partition=task.partition_key.partition if task.partition_key else Record(), + file_size_in_bytes=len(fo), + # After this has been fixed: + # https://github.com/apache/iceberg-python/issues/271 + # sort_order_id=task.sort_order_id, + sort_order_id=None, + # Just copy these from the table for now + spec_id=table_metadata.default_spec_id, + equality_ids=None, + key_metadata=None, + **statistics.to_serialized_dict(), + ) + + yield data_file def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]: diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 03dc3199bf..f982629f93 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -41,7 +41,6 @@ from pyiceberg.types import ( BinaryType, BooleanType, - DateType, IcebergType, IntegerType, ListType, @@ -51,8 +50,6 @@ PrimitiveType, StringType, StructType, - TimestampType, - TimestamptzType, TimeType, ) @@ -289,10 +286,7 @@ def partition_field_to_data_file_partition_field(partition_field_type: IcebergTy @partition_field_to_data_file_partition_field.register(LongType) -@partition_field_to_data_file_partition_field.register(DateType) @partition_field_to_data_file_partition_field.register(TimeType) -@partition_field_to_data_file_partition_field.register(TimestampType) -@partition_field_to_data_file_partition_field.register(TimestamptzType) def _(partition_field_type: PrimitiveType) -> IntegerType: return IntegerType() diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2ad1f7fe81..95143d7d95 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,13 +16,13 @@ # under the License. 
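(Editorial sketch, not part of the patch: the rewritten write_file above turns the single-task, unpartitioned path into a loop in which every WriteTask is written to its own Parquet file and yielded as a DataFile carrying the task's partition Record. A rough standalone illustration of that per-task pattern follows; plain PyArrow and a temporary directory stand in for PyIceberg's FileIO, WriteTask and statistics plumbing, and all names below are illustrative only.)

import os
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq

table = pa.table({
    "n_legs": [2, 2, 4, 4],
    "animal": ["Flamingo", "Parrot", "Dog", "Horse"],
})
# Pretend these are the per-partition slices produced by the partitioning step upstream.
partition_slices = {"n_legs=2": table.slice(0, 2), "n_legs=4": table.slice(2, 2)}

out_dir = tempfile.mkdtemp()
data_files = []
for partition_path, part in partition_slices.items():
    os.makedirs(os.path.join(out_dir, partition_path), exist_ok=True)
    file_path = os.path.join(out_dir, partition_path, "00000-0-example.parquet")
    # One Parquet file per partition slice.
    with pq.ParquetWriter(file_path, schema=table.schema) as writer:
        writer.write_table(part)
    # Read statistics back from the footer of the file that was just written.
    footer = pq.ParquetFile(file_path).metadata
    data_files.append((file_path, footer.num_rows, os.path.getsize(file_path)))

for path, num_rows, size in data_files:
    print(path, num_rows, size)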
from __future__ import annotations -import datetime import itertools import uuid import warnings from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass +from datetime import datetime from enum import Enum from functools import cached_property, singledispatch from itertools import chain @@ -72,6 +72,8 @@ INITIAL_PARTITION_SPEC_ID, PARTITION_FIELD_ID_START, PartitionField, + PartitionFieldValue, + PartitionKey, PartitionSpec, _PartitionNameGenerator, _visit_partition_field, @@ -708,7 +710,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl if update.ref_name == MAIN_BRANCH: metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id if "last_updated_ms" not in metadata_updates: - metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone()) metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [ SnapshotLogEntry( @@ -1123,9 +1125,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - if len(self.spec().fields) > 0: - raise ValueError("Cannot write to partitioned tables") - _check_schema_compatible(self.schema(), other_schema=df.schema) # cast if the two schemas are compatible but not equal if self.schema().as_arrow() != df.schema: @@ -2487,15 +2486,27 @@ class WriteTask: write_uuid: uuid.UUID task_id: int df: pa.Table + schema: Schema sort_order_id: Optional[int] = None + partition_key: Optional[PartitionKey] = None - # Later to be extended with partition information + def generate_data_file_partition_path(self) -> str: + if self.partition_key is None: + raise ValueError("Cannot generate partition path based on non-partitioned WriteTask") + return self.partition_key.to_path() def generate_data_file_filename(self, extension: str) -> str: # Mimics the behavior in the Java API: # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101 return f"00000-{self.task_id}-{self.write_uuid}.{extension}" + def generate_data_file_path(self, extension: str) -> str: + if self.partition_key: + file_path = f"{self.generate_data_file_partition_path()}/{self. 
generate_data_file_filename(extension)}" + return file_path + else: + return self.generate_data_file_filename(extension) + @dataclass(frozen=True) class AddFileTask: @@ -2523,15 +2534,34 @@ def _dataframe_to_data_files( """ from pyiceberg.io.pyarrow import write_file - if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: - raise ValueError("Cannot write to partitioned tables") - counter = itertools.count(0) write_uuid = write_uuid or uuid.uuid4() - # This is an iter, so we don't have to materialize everything every time - # This will be more relevant when we start doing partitioned writes - yield from write_file(io=io, table_metadata=table_metadata, tasks=iter([WriteTask(write_uuid, next(counter), df)])) + # this commented old line seems to be bug (not the correct way to check whether the table is partitioned), remember to highlight in pull request review + # if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: + + if any(len(spec.fields) > 0 for spec in table_metadata.partition_specs): + partitions = partition(table_metadata, df) + yield from write_file( + io=io, + table_metadata=table_metadata, + tasks=iter([ + WriteTask( + write_uuid=write_uuid, + task_id=next(counter), + df=partition.arrow_table_partition, + partition_key=partition.partition_key, + schema=table_metadata.schema(), + ) + for partition in partitions + ]), + ) + else: + yield from write_file( + io=io, + table_metadata=table_metadata, + tasks=iter([WriteTask(write_uuid=write_uuid, task_id=next(counter), df=df, schema=table_metadata.schema())]), + ) def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]: @@ -3083,7 +3113,7 @@ def snapshots(self) -> "pa.Table": additional_properties = None snapshots.append({ - 'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), + 'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), 'snapshot_id': snapshot.snapshot_id, 'parent_id': snapshot.parent_snapshot_id, 'operation': str(operation), @@ -3095,3 +3125,127 @@ def snapshots(self) -> "pa.Table": snapshots, schema=snapshots_schema, ) + + +@dataclass(frozen=True) +class TablePartition: + partition_key: PartitionKey + arrow_table_partition: pa.Table + + +def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]: + order = 'ascending' if not reverse else 'descending' + null_placement = 'at_start' if reverse else 'at_end' + return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement} + + +def group_by_partition_scheme( + iceberg_table_metadata: TableMetadata, arrow_table: pa.Table, partition_columns: list[str] +) -> pa.Table: + """Given a table sort it by current partition scheme with all transform functions supported.""" + from pyiceberg.transforms import IdentityTransform + + supported = {IdentityTransform} + if not all( + type(field.transform) in supported for field in iceberg_table_metadata.spec().fields if field in partition_columns + ): + raise ValueError( + f"Not all transforms are supported, get: {[transform in supported for transform in iceberg_table_metadata.spec().fields]}." 
+ ) + + # only works for identity + sort_options = _get_partition_sort_order(partition_columns, reverse=False) + sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement']) + return sorted_arrow_table + + +def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]: + arrow_table_cols = set(arrow_table.column_names) + partition_cols = [] + for transform_field in iceberg_table_metadata.spec().fields: + column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id) + if not column_name: + raise ValueError(f"{transform_field=} could not be found in {iceberg_table_metadata.schema()}.") + if column_name not in arrow_table_cols: + continue + partition_cols.append(column_name) + return partition_cols + + +def _get_table_partitions( + arrow_table: pa.Table, + partition_spec: PartitionSpec, + schema: Schema, + slice_instructions: list[dict[str, Any]], +) -> list[TablePartition]: + sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset']) + + partition_fields = partition_spec.fields + + offsets = [inst["offset"] for inst in sorted_slice_instructions] + projected_and_filtered = { + partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name] + .take(offsets) + .to_pylist() + for partition_field in partition_fields + } + + table_partitions = [] + for inst in sorted_slice_instructions: + partition_slice = arrow_table.slice(**inst) + fieldvalues = [ + PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][inst["offset"]]) + for partition_field in partition_fields + ] + partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema) + table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice)) + + return table_partitions + + +def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> Iterable[TablePartition]: + """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys. + + Example: + Input: + An arrow table with partition key of ['n_legs', 'year'] and with data of + {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021], + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], + 'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}. + The algrithm: + Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')] + and null_placement of "at_end". + This gives the same table as raw input. + Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')] + and null_placement : "at_start". + This gives: + [8, 7, 4, 5, 6, 3, 1, 2, 0] + Based on this we get partition groups of indices: + [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}] + We then retrieve the partition keys by offsets. + And slice the arrow table by offsets and lengths of each partition. 
+ """ + import pyarrow as pa + + partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table) + arrow_table = group_by_partition_scheme(iceberg_table_metadata, arrow_table, partition_columns) + + reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) + reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() + + slice_instructions: list[dict[str, Any]] = [] + last = len(reversed_indices) + reversed_indices_size = len(reversed_indices) + ptr = 0 + while ptr < reversed_indices_size: + group_size = last - reversed_indices[ptr] + offset = reversed_indices[ptr] + slice_instructions.append({"offset": offset, "length": group_size}) + last = reversed_indices[ptr] + ptr = ptr + group_size + + table_partitions: list[TablePartition] = _get_table_partitions( + arrow_table, iceberg_table_metadata.spec(), iceberg_table_metadata.schema(), slice_instructions + ) + + return table_partitions diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index b20f617e32..01cfa2255e 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -220,9 +220,6 @@ def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier database_name, _table_name = random_identifier catalog.create_namespace(database_name) table = catalog.create_table(random_identifier, pyarrow_table.schema) - print(pyarrow_table.schema) - print(table.schema().as_struct()) - print() table.overwrite(pyarrow_table) diff --git a/tests/integration/test_partitioned_write.py b/tests/integration/test_partitioned_write.py new file mode 100644 index 0000000000..242d4b4966 --- /dev/null +++ b/tests/integration/test_partitioned_write.py @@ -0,0 +1,534 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
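(Editorial sketch, not part of the patch: before the new integration tests, here is a minimal standalone reproduction of the slicing algorithm implemented in partition() above. It uses plain PyArrow only and the same worked example as the docstring, so the printed values can be checked against it.)

import pyarrow as pa
import pyarrow.compute as pc

arrow_table = pa.table({
    "year": [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
    "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"],
})
partition_columns = ["n_legs", "year"]  # identity-transform partition columns

# 1. Group rows together by sorting on the partition columns (ascending, nulls last).
sorted_table = arrow_table.sort_by(
    sorting=[(col, "ascending") for col in partition_columns], null_placement="at_end"
)

# 2. Compute sort indices in the reverse order (descending, nulls first).
reversed_indices = pc.sort_indices(
    sorted_table,
    sort_keys=[(col, "descending") for col in partition_columns],
    null_placement="at_start",
).to_pylist()
print(reversed_indices)  # [8, 7, 4, 5, 6, 3, 1, 2, 0], as in the docstring example

# 3. Walk the reversed indices to derive one (offset, length) slice per partition group.
slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    length = last - offset
    slice_instructions.append({"offset": offset, "length": length})
    last = offset
    ptr += length
print(slice_instructions)
# [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3},
#  {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]

# Each slice is one partition of the sorted table.
for inst in slice_instructions:
    print(sorted_table.slice(**inst).to_pydict())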
+# pylint:disable=redefined-outer-name +import uuid +from datetime import date, datetime + +import pyarrow as pa +import pytest +import pytz +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog, load_catalog +from pyiceberg.exceptions import NamespaceAlreadyExistsError, NoSuchTableError +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DoubleType, + FixedType, + FloatType, + IntegerType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, +) + + +@pytest.fixture() +def catalog() -> Catalog: + catalog = load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + "s3.secret-access-key": "password", + }, + ) + + try: + catalog.create_namespace("default") + except NamespaceAlreadyExistsError: + pass + + return catalog + + +TEST_DATA_WITH_NULL = { + 'bool': [False, None, True], + 'string': ['a', None, 'z'], + # Go over the 16 bytes to kick in truncation + 'string_long': ['a' * 22, None, 'z' * 22], + 'int': [1, None, 9], + 'long': [1, None, 9], + 'float': [0.0, None, 0.9], + 'double': [0.0, None, 0.9], + 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'timestamptz': [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')), + ], + 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], + # Not supported by Spark + # 'time': [time(1, 22, 0), None, time(19, 25, 0)], + # Not natively supported by Arrow + # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], + 'binary': [b'\01', None, b'\22'], + 'fixed': [ + uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, + None, + uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, + ], +} + + +TABLE_SCHEMA = Schema( + NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="string", field_type=StringType(), required=False), + NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), + NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="long", field_type=LongType(), required=False), + NestedField(field_id=6, name="float", field_type=FloatType(), required=False), + NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="date", field_type=DateType(), required=False), + # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), + # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), + NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False), + NestedField(field_id=12, name="fixed", field_type=FixedType(16), required=False), +) + + +@pytest.fixture(scope="session") +def session_catalog() -> Catalog: + return load_catalog( + "local", + **{ + "type": "rest", + "uri": "http://localhost:8181", + "s3.endpoint": "http://localhost:9000", + "s3.access-key-id": "admin", + 
"s3.secret-access-key": "password", + }, + ) + + +@pytest.fixture(scope="session") +def arrow_table_with_null() -> pa.Table: + """PyArrow table with all kinds of columns""" + pa_schema = pa.schema([ + ("bool", pa.bool_()), + ("string", pa.string()), + ("string_long", pa.string()), + ("int", pa.int32()), + ("long", pa.int64()), + ("float", pa.float32()), + ("double", pa.float64()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ("date", pa.date32()), + # Not supported by Spark + # ("time", pa.time64("us")), + # Not natively supported by Arrow + # ("uuid", pa.fixed(16)), + ("binary", pa.large_binary()), + ("fixed", pa.binary(16)), + ]) + return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v1_with_null_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '1'}, + ) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v1_appended_with_null_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), # name has to be the same for identity transform + properties={'format-version': '1'}, + ) + + for _ in range(2): + tbl.append(arrow_table_with_null) + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v2_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v2_with_null_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + + tbl = session_catalog.create_table( + 
identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '2'}, + ) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v2_appended_with_null_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '2'}, + ) + + for _ in range(2): + tbl.append(arrow_table_with_null) + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '1'}, + ) + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + with tbl.transaction() as tx: + tx.upgrade_table_version(format_version=2) + + tbl.append(arrow_table_with_null) + + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_null_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: + identifier = f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}" + df = spark.table(identifier) + df.show(20, False) + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: str, 
format_version: int) -> None: + identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 4, f"Expected 6 rows for {col}" + + +@pytest.fixture(scope="session") +def spark() -> SparkSession: + import importlib.metadata + import os + + spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2]) + scala_version = "2.12" + iceberg_version = "1.4.3" + + os.environ["PYSPARK_SUBMIT_ARGS"] = ( + f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version}," + f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell" + ) + os.environ["AWS_REGION"] = "us-east-1" + os.environ["AWS_ACCESS_KEY_ID"] = "admin" + os.environ["AWS_SECRET_ACCESS_KEY"] = "password" + + spark = ( + SparkSession.builder.appName("PyIceberg integration test") + .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") + .config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog") + .config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog") + .config("spark.sql.catalog.integration.uri", "http://localhost:8181") + .config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO") + .config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/") + .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000") + .config("spark.sql.catalog.integration.s3.path-style-access", "true") + .config("spark.sql.defaultCatalog", "integration") + .getOrCreate() + ) + + return spark + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +def test_query_filter_v1_v2_append_null(spark: SparkSession, part_col: str) -> None: + identifier = f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{part_col}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 4, f"Expected 4 row for {col}" + + +@pytest.mark.newyork +def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_summaries" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '2'}, + ) + + tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + + operations = [row.operation for row in rows] + assert operations == ['append', 'append'] + + summaries = [row.summary for row in rows] + assert summaries[0] == { + 'changed-partition-count': '3', + 'added-data-files': '3', + 'added-files-size': '15029', + 'added-records': '3', + 'total-data-files': '3', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '15029', + 'total-position-deletes': '0', + 'total-records': '3', + } + + assert summaries[1] == { + 
'changed-partition-count': '3', + 'added-data-files': '3', + 'added-files-size': '15029', + 'added-records': '3', + 'total-data-files': '6', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '30058', + 'total-position-deletes': '0', + 'total-records': '6', + } + + +@pytest.mark.integration +def test_data_files_with_table_partitioned_with_null( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + identifier = "default.arrow_data_files" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '1'}, + ) + + tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + + # added_data_files_count, existing_data_files_count, deleted_data_files_count + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [3, 3, 3] + assert [row.existing_data_files_count for row in rows] == [ + 0, + 0, + 0, + ] + assert [row.deleted_data_files_count for row in rows] == [0, 0, 0] + + +@pytest.mark.integration +def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_data_files" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '1'}, + ) + + with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"): + tbl.append("not a df") \ No newline at end of file diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py index fdc13ae752..c670bc4846 100644 --- a/tests/integration/test_reads.py +++ b/tests/integration/test_reads.py @@ -274,7 +274,6 @@ def test_ray_nan_rewritten(catalog: Catalog) -> None: def test_ray_not_nan_count(catalog: Catalog) -> None: table_test_null_nan_rewritten = catalog.load_table("default.test_null_nan_rewritten") ray_dataset = table_test_null_nan_rewritten.scan(row_filter=NotNaN("col_numeric"), selected_fields=("idx",)).to_ray() - print(ray_dataset.take()) assert ray_dataset.count() == 2 diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index 87c33d651b..17f9931d6e 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -439,7 +439,6 @@ def test_write_parquet_other_properties( properties: Dict[str, Any], expected_kwargs: Dict[str, Any], ) -> None: - print(type(mocker)) identifier = "default.test_write_parquet_other_properties" # The properties we test cannot be checked on the resulting Parquet file, so we spy on the ParquetWriter call instead From 870e49bf8b38d238ca746ed20b92aef0ac4c1278 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 29 Mar 2024 04:19:27 +0000 Subject: [PATCH 02/15] remove unnecessary fixture --- tests/integration/test_partitioned_write.py | 74 +-------------------- 1 file changed, 2 insertions(+), 72 deletions(-) diff --git 
a/tests/integration/test_partitioned_write.py b/tests/integration/test_partitioned_write.py index 7b872616d9..dbe5042dd2 100644 --- a/tests/integration/test_partitioned_write.py +++ b/tests/integration/test_partitioned_write.py @@ -23,8 +23,8 @@ import pytz from pyspark.sql import SparkSession -from pyiceberg.catalog import Catalog, load_catalog -from pyiceberg.exceptions import NamespaceAlreadyExistsError, NoSuchTableError +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.transforms import IdentityTransform @@ -43,28 +43,6 @@ TimestamptzType, ) - -@pytest.fixture() -def catalog() -> Catalog: - catalog = load_catalog( - "local", - **{ - "type": "rest", - "uri": "http://localhost:8181", - "s3.endpoint": "http://localhost:9000", - "s3.access-key-id": "admin", - "s3.secret-access-key": "password", - }, - ) - - try: - catalog.create_namespace("default") - except NamespaceAlreadyExistsError: - pass - - return catalog - - TEST_DATA_WITH_NULL = { 'bool': [False, None, True], 'string': ['a', None, 'z'], @@ -112,20 +90,6 @@ def catalog() -> Catalog: ) -@pytest.fixture(scope="session") -def session_catalog() -> Catalog: - return load_catalog( - "local", - **{ - "type": "rest", - "uri": "http://localhost:8181", - "s3.endpoint": "http://localhost:9000", - "s3.access-key-id": "admin", - "s3.secret-access-key": "password", - }, - ) - - @pytest.fixture(scope="session") def arrow_table_with_null() -> pa.Table: """PyArrow table with all kinds of columns""" @@ -374,40 +338,6 @@ def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: s assert df.where(f"{col} is not null").count() == 4, f"Expected 6 rows for {col}" -@pytest.fixture(scope="session") -def spark() -> SparkSession: - import importlib.metadata - import os - - spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2]) - scala_version = "2.12" - iceberg_version = "1.4.3" - - os.environ["PYSPARK_SUBMIT_ARGS"] = ( - f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version}," - f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell" - ) - os.environ["AWS_REGION"] = "us-east-1" - os.environ["AWS_ACCESS_KEY_ID"] = "admin" - os.environ["AWS_SECRET_ACCESS_KEY"] = "password" - - spark = ( - SparkSession.builder.appName("PyIceberg integration test") - .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") - .config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog") - .config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog") - .config("spark.sql.catalog.integration.uri", "http://localhost:8181") - .config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO") - .config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/") - .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000") - .config("spark.sql.catalog.integration.s3.path-style-access", "true") - .config("spark.sql.defaultCatalog", "integration") - .getOrCreate() - ) - - return spark - - @pytest.mark.integration @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] From 6020297268ba4ae3ae0bf1e53e14c8ad3830301e Mon Sep 17 00:00:00 2001 From: Adrian Qin 
<147659252+jqin61@users.noreply.github.com> Date: Fri, 29 Mar 2024 19:43:24 +0000 Subject: [PATCH 03/15] added null/empty table tests; fixed part of PR comments --- pyiceberg/table/__init__.py | 25 ++-- tests/integration/test_partitioned_write.py | 157 +++++++++++++++++++- tests/integration/test_writes.py | 6 +- 3 files changed, 165 insertions(+), 23 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4f7e2f2598..86906e7257 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2505,7 +2505,7 @@ def generate_data_file_filename(self, extension: str) -> str: def generate_data_file_path(self, extension: str) -> str: if self.partition_key: - file_path = f"{self.generate_data_file_partition_path()}/{self. generate_data_file_filename(extension)}" + file_path = f"{self.generate_data_file_partition_path()}/{self.generate_data_file_filename(extension)}" return file_path else: return self.generate_data_file_filename(extension) @@ -2549,7 +2549,7 @@ def _dataframe_to_data_files( "Fail to get neither TableProperties.WRITE_TARGET_FILE_SIZE_BYTES nor WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT for writing target data file." ) - if any(len(spec.fields) > 0 for spec in table_metadata.partition_specs): + if len(table_metadata.spec().fields) > 0: partitions = partition(table_metadata, df) yield from write_file( io=io, @@ -3155,15 +3155,13 @@ def _get_partition_sort_order(partition_columns: list[str], reverse: bool = Fals def group_by_partition_scheme( iceberg_table_metadata: TableMetadata, arrow_table: pa.Table, partition_columns: list[str] ) -> pa.Table: - """Given a table sort it by current partition scheme with all transform functions supported.""" + """Given a table, sort it by current partition scheme.""" from pyiceberg.transforms import IdentityTransform supported = {IdentityTransform} - if not all( - type(field.transform) in supported for field in iceberg_table_metadata.spec().fields if field in partition_columns - ): + if not all(type(field.transform) in supported for field in iceberg_table_metadata.spec().fields): raise ValueError( - f"Not all transforms are supported, get: {[transform in supported for transform in iceberg_table_metadata.spec().fields]}." + f"Not all transforms are supported, expected: {supported}, but get: {[str(field) for field in iceberg_table_metadata.spec().fields if field.transform not in supported]}." 
) # only works for identity @@ -3172,15 +3170,12 @@ def group_by_partition_scheme( return sorted_arrow_table -def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]: - arrow_table_cols = set(arrow_table.column_names) +def get_partition_columns(iceberg_table_metadata: TableMetadata) -> list[str]: partition_cols = [] - for transform_field in iceberg_table_metadata.spec().fields: - column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id) + for partition_field in iceberg_table_metadata.spec().fields: + column_name = iceberg_table_metadata.schema().find_column_name(partition_field.source_id) if not column_name: - raise ValueError(f"{transform_field=} could not be found in {iceberg_table_metadata.schema()}.") - if column_name not in arrow_table_cols: - continue + raise ValueError(f"{partition_field=} could not be found in {iceberg_table_metadata.schema()}.") partition_cols.append(column_name) return partition_cols @@ -3240,7 +3235,7 @@ def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> I """ import pyarrow as pa - partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table) + partition_columns = get_partition_columns(iceberg_table_metadata) arrow_table = group_by_partition_scheme(iceberg_table_metadata, arrow_table, partition_columns) reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) diff --git a/tests/integration/test_partitioned_write.py b/tests/integration/test_partitioned_write.py index dbe5042dd2..f846df3c23 100644 --- a/tests/integration/test_partitioned_write.py +++ b/tests/integration/test_partitioned_write.py @@ -114,6 +114,18 @@ def arrow_table_with_null() -> pa.Table: return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) +@pytest.fixture(scope="session") +def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" + return pa.Table.from_pylist([], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" + return pa.Table.from_pylist([{}, {}], schema=pa_schema) + + @pytest.fixture(scope="session", autouse=True) def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: partition_cols = [ @@ -151,6 +163,77 @@ def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_nu assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" +@pytest.fixture(scope="session", autouse=True) +def table_v1_without_data_partitioned(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v1_without_data_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '1'}, + ) + tbl.append(arrow_table_without_data) + assert 
tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + +@pytest.fixture(scope="session", autouse=True) +def table_v1_with_only_nulls_partitioned(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v1_with_only_nulls_partitioned_on_col_{partition_col}" + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '1'}, + ) + tbl.append(arrow_table_with_only_nulls) + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + @pytest.fixture(scope="session", autouse=True) def table_v1_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: partition_cols = [ @@ -214,7 +297,6 @@ def table_v2_with_null_partitioned(session_catalog: Catalog, arrow_table_with_nu pass nested_field = TABLE_SCHEMA.find_field(partition_col) source_id = nested_field.field_id - tbl = session_catalog.create_table( identifier=identifier, schema=TABLE_SCHEMA, @@ -228,6 +310,42 @@ def table_v2_with_null_partitioned(session_catalog: Catalog, arrow_table_with_nu assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" +@pytest.fixture(scope="session", autouse=True) +def table_v2_without_data_partitioned(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = f"default.arrow_table_v2_without_data_partitioned_on_col_{partition_col}" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '2'}, + ) + tbl.append(arrow_table_without_data) + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + @pytest.fixture(scope="session", autouse=True) def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: partition_cols = [ @@ -320,9 +438,36 @@ def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_nu def test_query_filter_null_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}" df = spark.table(identifier) - df.show(20, False) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null 
row for {col} is null" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_without_data_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: + identifier = f"default.arrow_table_v{format_version}_without_data_partitioned_on_col_{part_col}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_only_nulls_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: + identifier = f"default.arrow_table_v1_with_only_nulls_partitioned_on_col_{part_col}" + df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): - assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" @pytest.mark.integration @@ -335,7 +480,8 @@ def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: s df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): df = spark.table(identifier) - assert df.where(f"{col} is not null").count() == 4, f"Expected 6 rows for {col}" + assert df.where(f"{col} is not null").count() == 4, f"Expected 4 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" @pytest.mark.integration @@ -347,7 +493,8 @@ def test_query_filter_v1_v2_append_null(spark: SparkSession, part_col: str) -> N df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): df = spark.table(identifier) - assert df.where(f"{col} is not null").count() == 4, f"Expected 4 row for {col}" + assert df.where(f"{col} is not null").count() == 4, f"Expected 4 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index 0186e662dc..5dd80588ca 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -247,7 +247,7 @@ def test_query_filter_without_data(spark: SparkSession, col: str, format_version identifier = f"default.arrow_table_v{format_version}_without_data" df = spark.table(identifier) assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" - assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" @pytest.mark.integration @@ -256,8 +256,8 @@ def test_query_filter_without_data(spark: SparkSession, col: str, format_version def test_query_filter_only_nulls(spark: SparkSession, col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_only_nulls" df = spark.table(identifier) - assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" - 
assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 rows for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" @pytest.mark.integration From aecd7adc25c09ee3de0fdc8ff2ab2d78bd4f804d Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 29 Mar 2024 21:58:00 +0000 Subject: [PATCH 04/15] tests for unsupported transforms; unit tests for partition slicing algorithm --- pyiceberg/table/__init__.py | 48 ++++---- pyiceberg/typedef.py | 4 + tests/integration/test_partitioned_write.py | 118 +++++++++++++++++++- tests/integration/test_partitioning_key.py | 2 +- tests/table/test_init.py | 42 ++++++- 5 files changed, 180 insertions(+), 34 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 86906e7257..8672f70465 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1128,6 +1128,12 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") + supported = {IdentityTransform} + if not all(type(field.transform) in supported for field in self.metadata.spec().fields): + raise ValueError( + f"All transforms are not supported, expected: {supported}, but get: {[str(field) for field in self.metadata.spec().fields if field.transform not in supported]}." + ) + _check_schema_compatible(self.schema(), other_schema=df.schema) # cast if the two schemas are compatible but not equal if self.schema().as_arrow() != df.schema: @@ -2550,7 +2556,7 @@ def _dataframe_to_data_files( ) if len(table_metadata.spec().fields) > 0: - partitions = partition(table_metadata, df) + partitions = partition(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) yield from write_file( io=io, table_metadata=table_metadata, @@ -3152,30 +3158,23 @@ def _get_partition_sort_order(partition_columns: list[str], reverse: bool = Fals return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement} -def group_by_partition_scheme( - iceberg_table_metadata: TableMetadata, arrow_table: pa.Table, partition_columns: list[str] -) -> pa.Table: +def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table: """Given a table, sort it by current partition scheme.""" - from pyiceberg.transforms import IdentityTransform - - supported = {IdentityTransform} - if not all(type(field.transform) in supported for field in iceberg_table_metadata.spec().fields): - raise ValueError( - f"Not all transforms are supported, expected: {supported}, but get: {[str(field) for field in iceberg_table_metadata.spec().fields if field.transform not in supported]}." 
- ) - - # only works for identity + # only works for identity for now sort_options = _get_partition_sort_order(partition_columns, reverse=False) sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement']) return sorted_arrow_table -def get_partition_columns(iceberg_table_metadata: TableMetadata) -> list[str]: +def get_partition_columns( + spec: PartitionSpec, + schema: Schema, +) -> list[str]: partition_cols = [] - for partition_field in iceberg_table_metadata.spec().fields: - column_name = iceberg_table_metadata.schema().find_column_name(partition_field.source_id) + for partition_field in spec.fields: + column_name = schema.find_column_name(partition_field.source_id) if not column_name: - raise ValueError(f"{partition_field=} could not be found in {iceberg_table_metadata.schema()}.") + raise ValueError(f"{partition_field=} could not be found in {schema}.") partition_cols.append(column_name) return partition_cols @@ -3199,19 +3198,18 @@ def _get_table_partitions( } table_partitions = [] - for inst in sorted_slice_instructions: + for idx, inst in enumerate(sorted_slice_instructions): partition_slice = arrow_table.slice(**inst) fieldvalues = [ - PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][inst["offset"]]) + PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx]) for partition_field in partition_fields ] partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema) table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice)) - return table_partitions -def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> Iterable[TablePartition]: +def partition(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[TablePartition]: """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys. 
Example: @@ -3235,8 +3233,8 @@ def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> I """ import pyarrow as pa - partition_columns = get_partition_columns(iceberg_table_metadata) - arrow_table = group_by_partition_scheme(iceberg_table_metadata, arrow_table, partition_columns) + partition_columns = get_partition_columns(spec=spec, schema=schema) + arrow_table = group_by_partition_scheme(arrow_table, partition_columns) reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() @@ -3252,8 +3250,6 @@ def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> I last = reversed_indices[ptr] ptr = ptr + group_size - table_partitions: list[TablePartition] = _get_table_partitions( - arrow_table, iceberg_table_metadata.spec(), iceberg_table_metadata.schema(), slice_instructions - ) + table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) return table_partitions diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py index e57bf3490c..9d8633fbcf 100644 --- a/pyiceberg/typedef.py +++ b/pyiceberg/typedef.py @@ -199,3 +199,7 @@ def __repr__(self) -> str: def record_fields(self) -> List[str]: """Return values of all the fields of the Record class except those specified in skip_fields.""" return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name] + + def __hash__(self) -> int: + """Return hash value of the Record class.""" + return hash(str(self)) diff --git a/tests/integration/test_partitioned_write.py b/tests/integration/test_partitioned_write.py index f846df3c23..667a826773 100644 --- a/tests/integration/test_partitioned_write.py +++ b/tests/integration/test_partitioned_write.py @@ -27,7 +27,15 @@ from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.transforms import IdentityTransform +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, +) from pyiceberg.types import ( BinaryType, BooleanType, @@ -91,9 +99,8 @@ @pytest.fixture(scope="session") -def arrow_table_with_null() -> pa.Table: - """PyArrow table with all kinds of columns""" - pa_schema = pa.schema([ +def pa_schema() -> pa.Schema: + return pa.schema([ ("bool", pa.bool_()), ("string", pa.string()), ("string_long", pa.string()), @@ -111,6 +118,11 @@ def arrow_table_with_null() -> pa.Table: ("binary", pa.large_binary()), ("fixed", pa.binary(16)), ]) + + +@pytest.fixture(scope="session") +def arrow_table_with_null(pa_schema: Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) @@ -346,6 +358,41 @@ def table_v2_without_data_partitioned(session_catalog: Catalog, arrow_table_with assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" +@pytest.fixture(scope="session", autouse=True) +def table_v2_with_only_nulls_partitioned(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: + partition_cols = [ + 'int', + 'bool', + 'string', + "string_long", + "long", + "float", + "double", + "date", + "timestamptz", + "timestamp", + "binary", + ] + for partition_col in partition_cols: + identifier = 
f"default.arrow_table_v2_with_only_nulls_partitioned_on_col_{partition_col}" + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + nested_field = TABLE_SCHEMA.find_field(partition_col) + source_id = nested_field.field_id + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec( + PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) + ), + properties={'format-version': '2'}, + ) + tbl.append(arrow_table_with_only_nulls) + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + + @pytest.fixture(scope="session", autouse=True) def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: partition_cols = [ @@ -463,7 +510,7 @@ def test_query_filter_without_data_partitioned(spark: SparkSession, part_col: st ) @pytest.mark.parametrize("format_version", [1, 2]) def test_query_filter_only_nulls_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: - identifier = f"default.arrow_table_v1_with_only_nulls_partitioned_on_col_{part_col}" + identifier = f"default.arrow_table_v{format_version}_with_only_nulls_partitioned_on_col_{part_col}" df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" @@ -591,7 +638,7 @@ def test_data_files_with_table_partitioned_with_null( @pytest.mark.integration -def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: +def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> None: identifier = "default.arrow_data_files" try: @@ -608,3 +655,62 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_ with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"): tbl.append("not a df") + + +@pytest.mark.integration +@pytest.mark.parametrize( + "spec", + [ + # mixed with non-identity is not supported + ( + PartitionSpec( + PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"), + PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"), + ) + ), + # none of non-identity is supported + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))), + (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))), + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), + (PartitionSpec(PartitionField(source_id=2, 
field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=MonthTransform(), name="timestamp_month"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=HourTransform(), name="date_hour"))), + ], +) +def test_unsupported_transform( + spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + identifier = "default.unsupported_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=spec, + properties={'format-version': '1'}, + ) + + with pytest.raises(ValueError, match="All transforms are not supported.*"): + tbl.append(arrow_table_with_null) diff --git a/tests/integration/test_partitioning_key.py b/tests/integration/test_partitioning_key.py index 12056bac1e..d89ecaf202 100644 --- a/tests/integration/test_partitioning_key.py +++ b/tests/integration/test_partitioning_key.py @@ -749,7 +749,7 @@ def test_partition_key( # key.to_path() generates the hive partitioning part of the to-write parquet file path assert key.to_path() == expected_hive_partition_path_slice - # Justify expected values are not made up but conform to spark behaviors + # Justify expected values are not made up but conforming to spark behaviors if spark_create_table_sql_for_justification is not None and spark_data_insert_sql_for_justification is not None: try: spark.sql(f"drop table {identifier}") diff --git a/tests/table/test_init.py b/tests/table/test_init.py index f1191295f3..57618ed5f7 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -66,6 +66,7 @@ _check_schema_compatible, _match_deletes_to_data_file, _TableMetadataUpdateContext, + partition, update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id @@ -82,7 +83,11 @@ SortField, SortOrder, ) -from pyiceberg.transforms import BucketTransform, IdentityTransform +from pyiceberg.transforms import ( + BucketTransform, + IdentityTransform, +) +from pyiceberg.typedef import Record from pyiceberg.types import ( BinaryType, BooleanType, @@ -1139,3 
+1144,38 @@ def test_serialize_commit_table_request() -> None: deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json()) assert request == deserialized_request + + + +def test_partition() -> None: + import pyarrow as pa + + test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) + test_schema = Schema( + NestedField(field_id=1, name='year', field_type=StringType(), required=False), + NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='animal', field_type=StringType(), required=False), + schema_id=1, + ) + test_data = { + 'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], + 'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], + } + arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), + ) + result = partition(partition_spec, test_schema, arrow_table) + assert {table_partition.partition_key.partition for table_partition in result} == { + Record(n_legs_identity=2, year_identity=2020), + Record(n_legs_identity=100, year_identity=2021), + Record(n_legs_identity=4, year_identity=2021), + Record(n_legs_identity=4, year_identity=2022), + Record(n_legs_identity=2, year_identity=2022), + Record(n_legs_identity=5, year_identity=2019), + } + assert ( + pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows + ) From d5f39f3385eebf50df709ff9562b4056455e8a52 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 29 Mar 2024 23:54:02 +0000 Subject: [PATCH 05/15] add a comprehensive partition unit test --- tests/table/test_init.py | 59 +++++++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 57618ed5f7..1e238e4cae 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -1146,36 +1146,51 @@ def test_serialize_commit_table_request() -> None: assert request == deserialized_request - -def test_partition() -> None: +def test_identity_partition_on_multi_columns() -> None: import pyarrow as pa - test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) + test_pa_schema = pa.schema([('born_year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) test_schema = Schema( - NestedField(field_id=1, name='year', field_type=StringType(), required=False), + NestedField(field_id=1, name='born_year', field_type=StringType(), required=False), NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True), NestedField(field_id=3, name='animal', field_type=StringType(), required=False), schema_id=1, ) - test_data = { - 'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], - 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], - 'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], - } - arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + # 5 partitions, 6 unique row values, 12 rows + test_rows = [ + (2021, 4, "Dog"), + (2022, 4, "Horse"), + (2022, 4, "Another Horse"), + (2021, 
100, "Centipede"), + (None, 4, "Kirin"), + (2021, None, "Fish"), + ] * 2 + expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))} partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), ) - result = partition(partition_spec, test_schema, arrow_table) - assert {table_partition.partition_key.partition for table_partition in result} == { - Record(n_legs_identity=2, year_identity=2020), - Record(n_legs_identity=100, year_identity=2021), - Record(n_legs_identity=4, year_identity=2021), - Record(n_legs_identity=4, year_identity=2022), - Record(n_legs_identity=2, year_identity=2022), - Record(n_legs_identity=5, year_identity=2019), - } - assert ( - pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows - ) + import random + + # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all + for _ in range(1000): + random.shuffle(test_rows) + test_data = { + 'born_year': [row[0] for row in test_rows], + 'n_legs': [row[1] for row in test_rows], + 'animal': [row[2] for row in test_rows], + } + arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + + result = partition(partition_spec, test_schema, arrow_table) + + assert {table_partition.partition_key.partition for table_partition in result} == expected + assert ( + pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows + == arrow_table.num_rows + ) + assert pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).sort_by([ + ('born_year', 'ascending'), + ('n_legs', 'ascending'), + ('animal', 'ascending'), + ]) == arrow_table.sort_by([('born_year', 'ascending'), ('n_legs', 'ascending'), ('animal', 'ascending')]) From fd484ef7f153e124d7564ff4deebff9c6455972d Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Sat, 30 Mar 2024 00:09:47 +0000 Subject: [PATCH 06/15] clean up --- tests/table/test_init.py | 42 +++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 1e238e4cae..15fd7d0b8b 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -1146,6 +1146,40 @@ def test_serialize_commit_table_request() -> None: assert request == deserialized_request +def test_partition_for_demo() -> None: + import pyarrow as pa + + test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) + test_schema = Schema( + NestedField(field_id=1, name='year', field_type=StringType(), required=False), + NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='animal', field_type=StringType(), required=False), + schema_id=1, + ) + test_data = { + 'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], + 'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], + } + arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), 
name="year_identity"), + ) + result = partition(partition_spec, test_schema, arrow_table) + assert {table_partition.partition_key.partition for table_partition in result} == { + Record(n_legs_identity=2, year_identity=2020), + Record(n_legs_identity=100, year_identity=2021), + Record(n_legs_identity=4, year_identity=2021), + Record(n_legs_identity=4, year_identity=2022), + Record(n_legs_identity=2, year_identity=2022), + Record(n_legs_identity=5, year_identity=2019), + } + assert ( + pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows + ) + + def test_identity_partition_on_multi_columns() -> None: import pyarrow as pa @@ -1185,11 +1219,9 @@ def test_identity_partition_on_multi_columns() -> None: result = partition(partition_spec, test_schema, arrow_table) assert {table_partition.partition_key.partition for table_partition in result} == expected - assert ( - pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows - == arrow_table.num_rows - ) - assert pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).sort_by([ + concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) + assert concatenated_arrow_table.num_rows == arrow_table.num_rows + assert concatenated_arrow_table.sort_by([ ('born_year', 'ascending'), ('n_legs', 'ascending'), ('animal', 'ascending'), From e8c9334f781e29e5fa95effcca6e082ae9301f61 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Mon, 1 Apr 2024 15:19:45 +0000 Subject: [PATCH 07/15] move common fixtures utils to utils.py and conftest --- tests/integration/test_writes/conftest.py | 42 ++++++++ .../test_partitioned_writes.py} | 101 +++--------------- .../{ => test_writes}/test_writes.py | 85 +-------------- tests/integration/test_writes/utils.py | 69 ++++++++++++ 4 files changed, 126 insertions(+), 171 deletions(-) create mode 100644 tests/integration/test_writes/conftest.py rename tests/integration/{test_partitioned_write.py => test_writes/test_partitioned_writes.py} (85%) rename tests/integration/{ => test_writes}/test_writes.py (88%) create mode 100644 tests/integration/test_writes/utils.py diff --git a/tests/integration/test_writes/conftest.py b/tests/integration/test_writes/conftest.py new file mode 100644 index 0000000000..4f98c20660 --- /dev/null +++ b/tests/integration/test_writes/conftest.py @@ -0,0 +1,42 @@ +import pyarrow as pa +from utils import TEST_DATA_WITH_NULL +import pytest + +@pytest.fixture(scope="session") +def pa_schema() -> pa.Schema: + return pa.schema([ + ("bool", pa.bool_()), + ("string", pa.string()), + ("string_long", pa.string()), + ("int", pa.int32()), + ("long", pa.int64()), + ("float", pa.float32()), + ("double", pa.float64()), + ("timestamp", pa.timestamp(unit="us")), + ("timestamptz", pa.timestamp(unit="us", tz="UTC")), + ("date", pa.date32()), + # Not supported by Spark + # ("time", pa.time64("us")), + # Not natively supported by Arrow + # ("uuid", pa.fixed(16)), + ("binary", pa.large_binary()), + ("fixed", pa.binary(16)), + ]) + + +@pytest.fixture(scope="session") +def arrow_table_with_null(pa_schema: pa.Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" + return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" + return 
pa.Table.from_pylist([], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: + """PyArrow table with all kinds of columns""" + return pa.Table.from_pylist([{}, {}], schema=pa_schema) \ No newline at end of file diff --git a/tests/integration/test_partitioned_write.py b/tests/integration/test_writes/test_partitioned_writes.py similarity index 85% rename from tests/integration/test_partitioned_write.py rename to tests/integration/test_writes/test_partitioned_writes.py index 667a826773..9e8852837c 100644 --- a/tests/integration/test_partitioned_write.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -51,91 +51,10 @@ TimestamptzType, ) -TEST_DATA_WITH_NULL = { - 'bool': [False, None, True], - 'string': ['a', None, 'z'], - # Go over the 16 bytes to kick in truncation - 'string_long': ['a' * 22, None, 'z' * 22], - 'int': [1, None, 9], - 'long': [1, None, 9], - 'float': [0.0, None, 0.9], - 'double': [0.0, None, 0.9], - 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=pytz.timezone('America/New_York')), - ], - 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], - # Not supported by Spark - # 'time': [time(1, 22, 0), None, time(19, 25, 0)], - # Not natively supported by Arrow - # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], - 'binary': [b'\01', None, b'\22'], - 'fixed': [ - uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, - None, - uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, - ], -} - - -TABLE_SCHEMA = Schema( - NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="string", field_type=StringType(), required=False), - NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), - NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), - NestedField(field_id=5, name="long", field_type=LongType(), required=False), - NestedField(field_id=6, name="float", field_type=FloatType(), required=False), - NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), - NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), - NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), - NestedField(field_id=10, name="date", field_type=DateType(), required=False), - # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), - # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), - NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False), - NestedField(field_id=12, name="fixed", field_type=FixedType(16), required=False), -) - - -@pytest.fixture(scope="session") -def pa_schema() -> pa.Schema: - return pa.schema([ - ("bool", pa.bool_()), - ("string", pa.string()), - ("string_long", pa.string()), - ("int", pa.int32()), - ("long", pa.int64()), - ("float", pa.float32()), - ("double", pa.float64()), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ("date", pa.date32()), - # Not supported by Spark - # ("time", pa.time64("us")), - # Not natively supported by Arrow - # ("uuid", pa.fixed(16)), - ("binary", pa.large_binary()), - ("fixed", 
pa.binary(16)), - ]) - - -@pytest.fixture(scope="session") -def arrow_table_with_null(pa_schema: Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) - +from utils import TEST_DATA_WITH_NULL, TABLE_SCHEMA -@pytest.fixture(scope="session") -def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([], schema=pa_schema) -@pytest.fixture(scope="session") -def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([{}, {}], schema=pa_schema) @pytest.fixture(scope="session", autouse=True) @@ -280,8 +199,8 @@ def table_v1_appended_with_null_partitioned(session_catalog: Catalog, arrow_tabl properties={'format-version': '1'}, ) - for _ in range(2): - tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" @@ -427,8 +346,9 @@ def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_tabl properties={'format-version': '2'}, ) - for _ in range(2): - tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" @@ -519,7 +439,7 @@ def test_query_filter_only_nulls_partitioned(spark: SparkSession, part_col: str, @pytest.mark.integration @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] + "part_col", ['int',]# 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] ) @pytest.mark.parametrize("format_version", [1, 2]) def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: @@ -527,8 +447,10 @@ def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: s df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): df = spark.table(identifier) - assert df.where(f"{col} is not null").count() == 4, f"Expected 4 non-null rows for {col}" - assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" + assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" + rows = spark.sql(f"select partition from {identifier}.files").collect() + assert len(rows) == 6 @pytest.mark.integration @@ -544,6 +466,7 @@ def test_query_filter_v1_v2_append_null(spark: SparkSession, part_col: str) -> N assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" +@pytest.mark.integration def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: identifier = "default.arrow_table_summaries" diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes/test_writes.py similarity index 88% rename from tests/integration/test_writes.py rename to tests/integration/test_writes/test_writes.py index 5dd80588ca..b583a427aa 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -54,86 +54,7 @@ TimestamptzType, ) -TEST_DATA_WITH_NULL = { - 'bool': 
[False, None, True], - 'string': ['a', None, 'z'], - # Go over the 16 bytes to kick in truncation - 'string_long': ['a' * 22, None, 'z' * 22], - 'int': [1, None, 9], - 'long': [1, None, 9], - 'float': [0.0, None, 0.9], - 'double': [0.0, None, 0.9], - 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], - # Not supported by Spark - # 'time': [time(1, 22, 0), None, time(19, 25, 0)], - # Not natively supported by Arrow - # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], - 'binary': [b'\01', None, b'\22'], - 'fixed': [ - uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, - None, - uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, - ], -} - -TABLE_SCHEMA = Schema( - NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="string", field_type=StringType(), required=False), - NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), - NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), - NestedField(field_id=5, name="long", field_type=LongType(), required=False), - NestedField(field_id=6, name="float", field_type=FloatType(), required=False), - NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), - NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), - NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), - NestedField(field_id=10, name="date", field_type=DateType(), required=False), - # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), - # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), - NestedField(field_id=12, name="binary", field_type=BinaryType(), required=False), - NestedField(field_id=13, name="fixed", field_type=FixedType(16), required=False), -) - - -@pytest.fixture(scope="session") -def pa_schema() -> pa.Schema: - return pa.schema([ - ("bool", pa.bool_()), - ("string", pa.string()), - ("string_long", pa.string()), - ("int", pa.int32()), - ("long", pa.int64()), - ("float", pa.float32()), - ("double", pa.float64()), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ("date", pa.date32()), - # Not supported by Spark - # ("time", pa.time64("us")), - # Not natively supported by Arrow - # ("uuid", pa.fixed(16)), - ("binary", pa.large_binary()), - ("fixed", pa.binary(16)), - ]) - - -@pytest.fixture(scope="session") -def arrow_table_with_null(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) - - -@pytest.fixture(scope="session") -def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([], schema=pa_schema) - - -@pytest.fixture(scope="session") -def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([{}, {}], schema=pa_schema) +from utils import TEST_DATA_WITH_NULL, TABLE_SCHEMA def _create_table( @@ -222,7 +143,6 @@ def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_nu assert tbl.format_version == 2, f"Expected 
v2, got: v{tbl.format_version}" - @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_query_count(spark: SparkSession, format_version: int) -> None: @@ -230,8 +150,9 @@ def test_query_count(spark: SparkSession, format_version: int) -> None: assert df.count() == 3, "Expected 3 rows" +from utils import TEST_DATA_WITH_NULL @pytest.mark.integration -@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys()) +@pytest.mark.parametrize("col", ["int"]) #TEST_DATA_WITH_NULL.keys()) @pytest.mark.parametrize("format_version", [1, 2]) def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_null" diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py new file mode 100644 index 0000000000..581d4fe222 --- /dev/null +++ b/tests/integration/test_writes/utils.py @@ -0,0 +1,69 @@ +from datetime import datetime, date, timezone +import uuid +import pytz +import pytest +import pyarrow as pa + +from pyiceberg.schema import Schema +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DoubleType, + FixedType, + FloatType, + IntegerType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, +) +TEST_DATA_WITH_NULL = { + 'bool': [False, None, True], + 'string': ['a', None, 'z'], + # Go over the 16 bytes to kick in truncation + 'string_long': ['a' * 22, None, 'z' * 22], + 'int': [1, None, 9], + 'long': [1, None, 9], + 'float': [0.0, None, 0.9], + 'double': [0.0, None, 0.9], + 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + # 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'timestamptz': [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], + # Not supported by Spark + # 'time': [time(1, 22, 0), None, time(19, 25, 0)], + # Not natively supported by Arrow + # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], + 'binary': [b'\01', None, b'\22'], + 'fixed': [ + uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, + None, + uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, + ], +} + + +TABLE_SCHEMA = Schema( + NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="string", field_type=StringType(), required=False), + NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), + NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="long", field_type=LongType(), required=False), + NestedField(field_id=6, name="float", field_type=FloatType(), required=False), + NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="date", field_type=DateType(), required=False), + # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), + # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), + NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False), + NestedField(field_id=12, name="fixed", 
field_type=FixedType(16), required=False), +) + + From 7595b6b902e3b3d19671e83b460b1b33b63267e6 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Mon, 1 Apr 2024 18:32:03 +0000 Subject: [PATCH 08/15] pull partitioned table fixtures into tests for more real-time feedback of running test --- tests/integration/test_writes/conftest.py | 12 +- .../test_writes/test_partitioned_writes.py | 490 +++++------------- tests/integration/test_writes/test_writes.py | 45 +- tests/integration/test_writes/utils.py | 38 +- 4 files changed, 164 insertions(+), 421 deletions(-) diff --git a/tests/integration/test_writes/conftest.py b/tests/integration/test_writes/conftest.py index 4f98c20660..0c12407d2a 100644 --- a/tests/integration/test_writes/conftest.py +++ b/tests/integration/test_writes/conftest.py @@ -1,7 +1,9 @@ import pyarrow as pa -from utils import TEST_DATA_WITH_NULL import pytest +from utils import TEST_DATA_WITH_NULL + + @pytest.fixture(scope="session") def pa_schema() -> pa.Schema: return pa.schema([ @@ -26,17 +28,17 @@ def pa_schema() -> pa.Schema: @pytest.fixture(scope="session") def arrow_table_with_null(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" + """PyArrow table with all kinds of columns.""" return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) @pytest.fixture(scope="session") def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" + """PyArrow table with all kinds of columns.""" return pa.Table.from_pylist([], schema=pa_schema) @pytest.fixture(scope="session") def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([{}, {}], schema=pa_schema) \ No newline at end of file + """PyArrow table with all kinds of columns.""" + return pa.Table.from_pylist([{}, {}], schema=pa_schema) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 9e8852837c..0d1b9ca3fe 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -15,18 +15,14 @@ # specific language governing permissions and limitations # under the License. 
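For context, the append flow that the rewritten partitioned-write tests below exercise can also be reproduced locally without the dockerized Spark/REST setup. The following is a minimal sketch only, assuming a pyiceberg build that already contains this series' identity-partitioned append; the SqlCatalog name, SQLite URI, warehouse path, and table identifier are placeholder choices, not part of the patch.

    from tempfile import mkdtemp

    import pyarrow as pa

    from pyiceberg.catalog.sql import SqlCatalog
    from pyiceberg.partitioning import PartitionField, PartitionSpec
    from pyiceberg.schema import Schema
    from pyiceberg.transforms import IdentityTransform
    from pyiceberg.types import IntegerType, NestedField, StringType

    warehouse = mkdtemp()  # throwaway local warehouse directory
    catalog = SqlCatalog("local", uri=f"sqlite:///{warehouse}/catalog.db", warehouse=f"file://{warehouse}")
    catalog.create_namespace("default")

    schema = Schema(
        NestedField(field_id=1, name="n_legs", field_type=IntegerType(), required=False),
        NestedField(field_id=2, name="animal", field_type=StringType(), required=False),
    )
    # Identity-partition on n_legs; as in the fixtures, the field name matches the source column.
    spec = PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="n_legs"))

    tbl = catalog.create_table(
        "default.partitioned_animals", schema=schema, partition_spec=spec, properties={"format-version": "2"}
    )
    tbl.append(
        pa.table({
            "n_legs": pa.array([2, 4, 4, 100], type=pa.int32()),
            "animal": ["Parrot", "Dog", "Horse", "Centipede"],
        })
    )
    # The append should fan out into one data file per distinct n_legs value.
    print(tbl.scan().to_arrow().num_rows)  # 4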
# pylint:disable=redefined-outer-name -import uuid -from datetime import date, datetime import pyarrow as pa import pytest -import pytz from pyspark.sql import SparkSession from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.partitioning import PartitionField, PartitionSpec -from pyiceberg.schema import Schema from pyiceberg.transforms import ( BucketTransform, DayTransform, @@ -36,374 +32,35 @@ TruncateTransform, YearTransform, ) -from pyiceberg.types import ( - BinaryType, - BooleanType, - DateType, - DoubleType, - FixedType, - FloatType, - IntegerType, - LongType, - NestedField, - StringType, - TimestampType, - TimestamptzType, -) - -from utils import TEST_DATA_WITH_NULL, TABLE_SCHEMA - - - - - -@pytest.fixture(scope="session", autouse=True) -def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v1_with_null_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '1'}, - ) - tbl.append(arrow_table_with_null) - - assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v1_without_data_partitioned(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v1_without_data_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '1'}, - ) - tbl.append(arrow_table_without_data) - assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v1_with_only_nulls_partitioned(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v1_with_only_nulls_partitioned_on_col_{partition_col}" - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - 
PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '1'}, - ) - tbl.append(arrow_table_with_only_nulls) - assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v1_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v1_appended_with_null_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), # name has to be the same for identity transform - properties={'format-version': '1'}, - ) - - tbl.append(arrow_table_with_null) - tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) - assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v2_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v2_with_null_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '2'}, - ) - tbl.append(arrow_table_with_null) - - assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v2_without_data_partitioned(session_catalog: Catalog, arrow_table_without_data: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v2_without_data_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '2'}, - ) - tbl.append(arrow_table_without_data) - assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def 
table_v2_with_only_nulls_partitioned(session_catalog: Catalog, arrow_table_with_only_nulls: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v2_with_only_nulls_partitioned_on_col_{partition_col}" - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '2'}, - ) - tbl.append(arrow_table_with_only_nulls) - assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v2_appended_with_null_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '2'}, - ) - - tbl.append(arrow_table_with_null) - tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) - - assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" - - -@pytest.fixture(scope="session", autouse=True) -def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: - partition_cols = [ - 'int', - 'bool', - 'string', - "string_long", - "long", - "float", - "double", - "date", - "timestamptz", - "timestamp", - "binary", - ] - for partition_col in partition_cols: - identifier = f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{partition_col}" - - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - nested_field = TABLE_SCHEMA.find_field(partition_col) - source_id = nested_field.field_id - tbl = session_catalog.create_table( - identifier=identifier, - schema=TABLE_SCHEMA, - partition_spec=PartitionSpec( - PartitionField(source_id=source_id, field_id=1001, transform=IdentityTransform(), name=partition_col) - ), - properties={'format-version': '1'}, - ) - tbl.append(arrow_table_with_null) - - assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" - - with tbl.transaction() as tx: - tx.upgrade_table_version(format_version=2) - - tbl.append(arrow_table_with_null) - - assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" +from utils import TABLE_SCHEMA, TEST_DATA_WITH_NULL, _create_table @pytest.mark.integration @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] + 
"part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_query_filter_null_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: +def test_query_filter_null_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given identifier = f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) assert df.count() == 3, f"Expected 3 total rows for {identifier}" for col in TEST_DATA_WITH_NULL.keys(): @@ -413,11 +70,30 @@ def test_query_filter_null_partitioned(spark: SparkSession, part_col: str, forma @pytest.mark.integration @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_query_filter_without_data_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: +def test_query_filter_without_data_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_without_data: pa.Table, part_col: str, format_version: int +) -> None: + # Given identifier = f"default.arrow_table_v{format_version}_without_data_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_without_data], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" @@ -426,11 +102,30 @@ def test_query_filter_without_data_partitioned(spark: SparkSession, part_col: st @pytest.mark.integration @pytest.mark.parametrize( - "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz'] + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_query_filter_only_nulls_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: +def test_query_filter_only_nulls_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_only_nulls: pa.Table, part_col: str, format_version: int +) -> None: + # Given identifier = 
f"default.arrow_table_v{format_version}_with_only_nulls_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_only_nulls], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" @@ -439,16 +134,39 @@ def test_query_filter_only_nulls_partitioned(spark: SparkSession, part_col: str, @pytest.mark.integration @pytest.mark.parametrize( - "part_col", ['int',]# 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] ) @pytest.mark.parametrize("format_version", [1, 2]) -def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: str, format_version: int) -> None: +def test_query_filter_appended_null_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + ) + # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): df = spark.table(identifier) assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" + # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], [C, C] rows = spark.sql(f"select partition from {identifier}.files").collect() assert len(rows) == 6 @@ -457,8 +175,36 @@ def test_query_filter_appended_null_partitioned(spark: SparkSession, part_col: s @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] ) -def test_query_filter_v1_v2_append_null(spark: SparkSession, part_col: str) -> None: +def test_query_filter_v1_v2_append_null( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str +) -> None: + # Given identifier = f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, 
field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": "1"}, + data=[], + partition_spec=partition_spec, + ) + tbl.append(arrow_table_with_null) + + # Then + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + # When + with tbl.transaction() as tx: + tx.upgrade_table_version(format_version=2) + tbl.append(arrow_table_with_null) + + # Then + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" df = spark.table(identifier) for col in TEST_DATA_WITH_NULL.keys(): df = spark.table(identifier) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index b583a427aa..996d166d77 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -18,10 +18,9 @@ import math import os import time -import uuid from datetime import date, datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict from urllib.parse import urlparse import pyarrow as pa @@ -36,42 +35,8 @@ from pyiceberg.catalog import Catalog from pyiceberg.catalog.sql import SqlCatalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.schema import Schema -from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files -from pyiceberg.typedef import Properties -from pyiceberg.types import ( - BinaryType, - BooleanType, - DateType, - DoubleType, - FixedType, - FloatType, - IntegerType, - LongType, - NestedField, - StringType, - TimestampType, - TimestamptzType, -) - -from utils import TEST_DATA_WITH_NULL, TABLE_SCHEMA - - -def _create_table( - session_catalog: Catalog, identifier: str, properties: Properties, data: Optional[List[pa.Table]] = None -) -> Table: - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties) - - if data: - for d in data: - tbl.append(d) - - return tbl +from pyiceberg.table import TableProperties, _dataframe_to_data_files +from utils import TEST_DATA_WITH_NULL, _create_table @pytest.fixture(scope="session", autouse=True) @@ -143,6 +108,7 @@ def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_nu assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_query_count(spark: SparkSession, format_version: int) -> None: @@ -150,9 +116,8 @@ def test_query_count(spark: SparkSession, format_version: int) -> None: assert df.count() == 3, "Expected 3 rows" -from utils import TEST_DATA_WITH_NULL @pytest.mark.integration -@pytest.mark.parametrize("col", ["int"]) #TEST_DATA_WITH_NULL.keys()) +@pytest.mark.parametrize("col", ["int"]) # TEST_DATA_WITH_NULL.keys()) @pytest.mark.parametrize("format_version", [1, 2]) def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_null" diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py index 581d4fe222..cbd04a66fc 100644 --- a/tests/integration/test_writes/utils.py +++ b/tests/integration/test_writes/utils.py @@ -1,10 +1,15 @@ -from datetime import datetime, date, timezone -import uuid -import pytz -import pytest +import uuid +from datetime 
import date, datetime, timezone +from typing import List, Optional + import pyarrow as pa +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.partitioning import PartitionSpec from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.typedef import Properties from pyiceberg.types import ( BinaryType, BooleanType, @@ -19,6 +24,7 @@ TimestampType, TimestamptzType, ) + TEST_DATA_WITH_NULL = { 'bool': [False, None, True], 'string': ['a', None, 'z'], @@ -67,3 +73,27 @@ ) +def _create_table( + session_catalog: Catalog, + identifier: str, + properties: Properties, + data: Optional[List[pa.Table]] = None, + partition_spec: Optional[PartitionSpec] = None, +) -> Table: + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + if partition_spec: + tbl = session_catalog.create_table( + identifier=identifier, schema=TABLE_SCHEMA, properties=properties, partition_spec=partition_spec + ) + else: + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties) + + if data: + for d in data: + tbl.append(d) + + return tbl From 9b371c06397bdfa0c24f77af02a11a7d00e126b0 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Mon, 1 Apr 2024 19:47:54 +0000 Subject: [PATCH 09/15] fix linting --- tests/conftest.py | 24 +++++++++- tests/integration/test_writes/conftest.py | 44 ------------------- .../test_writes/test_partitioned_writes.py | 9 ++-- tests/integration/test_writes/test_writes.py | 5 ++- tests/integration/test_writes/utils.py | 32 -------------- 5 files changed, 30 insertions(+), 84 deletions(-) delete mode 100644 tests/integration/test_writes/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py index d0f0d5920a..e7145a02b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ import socket import string import uuid -from datetime import date, datetime +from datetime import date, datetime, timezone from pathlib import Path from random import choice from tempfile import TemporaryDirectory @@ -1999,7 +1999,11 @@ def spark() -> SparkSession: 'float': [0.0, None, 0.9], 'double': [0.0, None, 0.9], 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'timestamptz': [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], # Not supported by Spark # 'time': [time(1, 22, 0), None, time(19, 25, 0)], @@ -2044,3 +2048,19 @@ def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table": """PyArrow table with all kinds of columns""" return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table": + import pyarrow as pa + + """PyArrow table with all kinds of columns.""" + return pa.Table.from_pylist([], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table": + import pyarrow as pa + + """PyArrow table with all kinds of columns.""" + return pa.Table.from_pylist([{}, {}], schema=pa_schema) diff --git a/tests/integration/test_writes/conftest.py b/tests/integration/test_writes/conftest.py deleted file mode 100644 index 0c12407d2a..0000000000 --- 
a/tests/integration/test_writes/conftest.py +++ /dev/null @@ -1,44 +0,0 @@ -import pyarrow as pa -import pytest - -from utils import TEST_DATA_WITH_NULL - - -@pytest.fixture(scope="session") -def pa_schema() -> pa.Schema: - return pa.schema([ - ("bool", pa.bool_()), - ("string", pa.string()), - ("string_long", pa.string()), - ("int", pa.int32()), - ("long", pa.int64()), - ("float", pa.float32()), - ("double", pa.float64()), - ("timestamp", pa.timestamp(unit="us")), - ("timestamptz", pa.timestamp(unit="us", tz="UTC")), - ("date", pa.date32()), - # Not supported by Spark - # ("time", pa.time64("us")), - # Not natively supported by Arrow - # ("uuid", pa.fixed(16)), - ("binary", pa.large_binary()), - ("fixed", pa.binary(16)), - ]) - - -@pytest.fixture(scope="session") -def arrow_table_with_null(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns.""" - return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) - - -@pytest.fixture(scope="session") -def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns.""" - return pa.Table.from_pylist([], schema=pa_schema) - - -@pytest.fixture(scope="session") -def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns.""" - return pa.Table.from_pylist([{}, {}], schema=pa_schema) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 0d1b9ca3fe..b5faca5b77 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -32,7 +32,8 @@ TruncateTransform, YearTransform, ) -from utils import TABLE_SCHEMA, TEST_DATA_WITH_NULL, _create_table +from tests.conftest import TEST_DATA_WITH_NULL +from utils import TABLE_SCHEMA, _create_table @pytest.mark.integration @@ -171,7 +172,7 @@ def test_query_filter_appended_null_partitioned( assert len(rows) == 6 -@pytest.mark.integration +@pytest.mark.newyork @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] ) @@ -201,12 +202,12 @@ def test_query_filter_v1_v2_append_null( # When with tbl.transaction() as tx: tx.upgrade_table_version(format_version=2) + tbl.append(arrow_table_with_null) # Then assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" - df = spark.table(identifier) - for col in TEST_DATA_WITH_NULL.keys(): + for col in TEST_DATA_WITH_NULL.keys(): # type: ignore df = spark.table(identifier) assert df.where(f"{col} is not null").count() == 4, f"Expected 4 non-null rows for {col}" assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 996d166d77..2036caec90 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -36,7 +36,8 @@ from pyiceberg.catalog.sql import SqlCatalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.table import TableProperties, _dataframe_to_data_files -from utils import TEST_DATA_WITH_NULL, _create_table +from tests.conftest import TEST_DATA_WITH_NULL +from utils import _create_table @pytest.fixture(scope="session", autouse=True) @@ -117,7 +118,7 @@ def test_query_count(spark: SparkSession, format_version: int) -> None: @pytest.mark.integration -@pytest.mark.parametrize("col", ["int"]) 
# TEST_DATA_WITH_NULL.keys()) +@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys()) @pytest.mark.parametrize("format_version", [1, 2]) def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_null" diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py index cbd04a66fc..b88db8a137 100644 --- a/tests/integration/test_writes/utils.py +++ b/tests/integration/test_writes/utils.py @@ -1,5 +1,3 @@ -import uuid -from datetime import date, datetime, timezone from typing import List, Optional import pyarrow as pa @@ -25,36 +23,6 @@ TimestamptzType, ) -TEST_DATA_WITH_NULL = { - 'bool': [False, None, True], - 'string': ['a', None, 'z'], - # Go over the 16 bytes to kick in truncation - 'string_long': ['a' * 22, None, 'z' * 22], - 'int': [1, None, 9], - 'long': [1, None, 9], - 'float': [0.0, None, 0.9], - 'double': [0.0, None, 0.9], - 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - # 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], - # Not supported by Spark - # 'time': [time(1, 22, 0), None, time(19, 25, 0)], - # Not natively supported by Arrow - # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], - 'binary': [b'\01', None, b'\22'], - 'fixed': [ - uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, - None, - uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, - ], -} - - TABLE_SCHEMA = Schema( NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), NestedField(field_id=2, name="string", field_type=StringType(), required=False), From 9c13dbb9b6e7874ad6f266ba795a07b952245513 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:05:34 +0000 Subject: [PATCH 10/15] license --- tests/integration/test_writes/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py index b88db8a137..4452095eb8 100644 --- a/tests/integration/test_writes/utils.py +++ b/tests/integration/test_writes/utils.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name from typing import List, Optional import pyarrow as pa From ebbec01f57a77332d9000b8ccc1fb6cb93695836 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Tue, 2 Apr 2024 14:38:48 +0000 Subject: [PATCH 11/15] save changes for switching codespaces --- pyiceberg/table/__init__.py | 17 ++++++----------- tests/table/test_init.py | 6 +++--- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 28d59da534..57c5b4127a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1133,10 +1133,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - supported = {IdentityTransform} - if not all(type(field.transform) in supported for field in self.metadata.spec().fields): + supported_transforms = {IdentityTransform} + if not all(type(field.transform) in supported_transforms for field in self.metadata.spec().fields): raise ValueError( - f"All transforms are not supported, expected: {supported}, but get: {[str(field) for field in self.metadata.spec().fields if field.transform not in supported]}." + f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.metadata.spec().fields if field.transform not in supported_transforms]}." ) _check_schema_compatible(self.schema(), other_schema=df.schema) @@ -2502,11 +2502,6 @@ class WriteTask: sort_order_id: Optional[int] = None partition_key: Optional[PartitionKey] = None - def generate_data_file_partition_path(self) -> str: - if self.partition_key is None: - raise ValueError("Cannot generate partition path based on non-partitioned WriteTask") - return self.partition_key.to_path() - def generate_data_file_filename(self, extension: str) -> str: # Mimics the behavior in the Java API: # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101 @@ -2514,7 +2509,7 @@ def generate_data_file_filename(self, extension: str) -> str: def generate_data_file_path(self, extension: str) -> str: if self.partition_key: - file_path = f"{self.generate_data_file_partition_path()}/{self.generate_data_file_filename(extension)}" + file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}" return file_path else: return self.generate_data_file_filename(extension) @@ -2559,7 +2554,7 @@ def _dataframe_to_data_files( ) if len(table_metadata.spec().fields) > 0: - partitions = partition(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) + partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) yield from write_file( io=io, table_metadata=table_metadata, @@ -3212,7 +3207,7 @@ def _get_table_partitions( return table_partitions -def partition(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> Iterable[TablePartition]: +def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]: """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
Example: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 15fd7d0b8b..2bc78f3197 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -64,9 +64,9 @@ UpdateSchema, _apply_table_update, _check_schema_compatible, + _determine_partitions, _match_deletes_to_data_file, _TableMetadataUpdateContext, - partition, update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id @@ -1166,7 +1166,7 @@ def test_partition_for_demo() -> None: PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), ) - result = partition(partition_spec, test_schema, arrow_table) + result = _determine_partitions(partition_spec, test_schema, arrow_table) assert {table_partition.partition_key.partition for table_partition in result} == { Record(n_legs_identity=2, year_identity=2020), Record(n_legs_identity=100, year_identity=2021), @@ -1216,7 +1216,7 @@ def test_identity_partition_on_multi_columns() -> None: } arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) - result = partition(partition_spec, test_schema, arrow_table) + result = _determine_partitions(partition_spec, test_schema, arrow_table) assert {table_partition.partition_key.partition for table_partition in result} == expected concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) From caddccedf41a47dc623851144453582635fc4e9a Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Tue, 2 Apr 2024 15:21:21 +0000 Subject: [PATCH 12/15] part of the comment fixes --- tests/conftest.py | 6 +++--- tests/integration/test_writes/test_partitioned_writes.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b62f0f47cd..2c08e29804 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2045,23 +2045,23 @@ def pa_schema() -> "pa.Schema": @pytest.fixture(scope="session") def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table with all kinds of columns.""" import pyarrow as pa - """PyArrow table with all kinds of columns""" return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) @pytest.fixture(scope="session") def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table without data.""" import pyarrow as pa - """PyArrow table with all kinds of columns.""" return pa.Table.from_pylist([], schema=pa_schema) @pytest.fixture(scope="session") def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table with only null values.""" import pyarrow as pa - """PyArrow table with all kinds of columns.""" return pa.Table.from_pylist([{}, {}], schema=pa_schema) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index b5faca5b77..d84b9745a7 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -172,7 +172,7 @@ def test_query_filter_appended_null_partitioned( assert len(rows) == 6 -@pytest.mark.newyork +@pytest.mark.integration @pytest.mark.parametrize( "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] ) From 
eab2865376ed46f1f99dd106c2f49d6eb5d85204 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:28:59 +0000 Subject: [PATCH 13/15] fix one type error --- pyiceberg/table/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 57c5b4127a..108623e09d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -2543,15 +2543,11 @@ def _dataframe_to_data_files( counter = itertools.count(0) write_uuid = write_uuid or uuid.uuid4() - target_file_size = PropertyUtil.property_as_int( + target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The property is set with non-None value. properties=table_metadata.properties, property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) - if target_file_size is None: - raise ValueError( - "Fail to get neither TableProperties.WRITE_TARGET_FILE_SIZE_BYTES nor WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT for writing target data file." - ) if len(table_metadata.spec().fields) > 0: partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) From 82dd3adff485c732d24158103ef43cd2e14ded8e Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Tue, 2 Apr 2024 19:40:19 +0000 Subject: [PATCH 14/15] add support for timetype --- pyiceberg/manifest.py | 21 +-------------------- pyiceberg/partitioning.py | 10 ++++++++-- tests/conftest.py | 3 +++ tests/integration/test_writes/utils.py | 1 + 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index f982629f93..5d20711c94 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -19,7 +19,6 @@ import math from abc import ABC, abstractmethod from enum import Enum -from functools import singledispatch from types import TracebackType from typing import ( Any, @@ -41,7 +40,6 @@ from pyiceberg.types import ( BinaryType, BooleanType, - IcebergType, IntegerType, ListType, LongType, @@ -50,7 +48,6 @@ PrimitiveType, StringType, StructType, - TimeType, ) UNASSIGNED_SEQ = -1 @@ -280,28 +277,12 @@ def __repr__(self) -> str: } -@singledispatch -def partition_field_to_data_file_partition_field(partition_field_type: IcebergType) -> PrimitiveType: - raise TypeError(f"Unsupported partition field type: {partition_field_type}") - - -@partition_field_to_data_file_partition_field.register(LongType) -@partition_field_to_data_file_partition_field.register(TimeType) -def _(partition_field_type: PrimitiveType) -> IntegerType: - return IntegerType() - - -@partition_field_to_data_file_partition_field.register(PrimitiveType) -def _(partition_field_type: PrimitiveType) -> PrimitiveType: - return partition_field_type - - def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType: data_file_partition_type = StructType(*[ NestedField( field_id=field.field_id, name=field.name, - field_type=partition_field_to_data_file_partition_field(field.field_type), + field_type=field.field_type, required=field.required, ) for field in partition_type.fields diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 16f158828d..a3cf255341 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -19,7 +19,7 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass -from datetime import date, 
datetime +from datetime import date, datetime, time from functools import cached_property, singledispatch from typing import ( Any, @@ -62,9 +62,10 @@ StructType, TimestampType, TimestamptzType, + TimeType, UUIDType, ) -from pyiceberg.utils.datetime import date_to_days, datetime_to_micros +from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros INITIAL_PARTITION_SPEC_ID = 0 PARTITION_FIELD_ID_START: int = 1000 @@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) -> Optional[int]: return date_to_days(value) if value is not None else None +@_to_partition_representation.register(TimeType) +def _(type: IcebergType, value: Optional[time]) -> Optional[int]: + return time_to_micros(value) if value is not None else None + + @_to_partition_representation.register(UUIDType) def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]: return str(value) if value is not None else None diff --git a/tests/conftest.py b/tests/conftest.py index 2c08e29804..4a820fedec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1999,6 +1999,7 @@ def spark() -> "SparkSession": 'long': [1, None, 9], 'float': [0.0, None, 0.9], 'double': [0.0, None, 0.9], + # 'time': [1_000_000, None, 3_000_000], # Example times: 1s, none, and 3s past midnight #Spark does not support time fields 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], 'timestamptz': [ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), @@ -2031,6 +2032,8 @@ def pa_schema() -> "pa.Schema": ("long", pa.int64()), ("float", pa.float32()), ("double", pa.float64()), + # Not supported by Spark + # ("time", pa.time64('us')), ("timestamp", pa.timestamp(unit="us")), ("timestamptz", pa.timestamp(unit="us", tz="UTC")), ("date", pa.date32()), diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py index 4452095eb8..792e25185d 100644 --- a/tests/integration/test_writes/utils.py +++ b/tests/integration/test_writes/utils.py @@ -48,6 +48,7 @@ NestedField(field_id=5, name="long", field_type=LongType(), required=False), NestedField(field_id=6, name="float", field_type=FloatType(), required=False), NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + # NestedField(field_id=8, name="time", field_type=TimeType(), required=False), # Spark does not support time fields NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), NestedField(field_id=10, name="date", field_type=DateType(), required=False), From f786ef4fbb75f26323cc714d0af589ad7f2733c9 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Thu, 4 Apr 2024 21:21:34 +0000 Subject: [PATCH 15/15] small fix for type hint --- pyiceberg/manifest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index e282da1204..3b8138b61a 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -277,7 +277,7 @@ def __repr__(self) -> str: } -def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType: +def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType: data_file_partition_type = StructType(*[ NestedField( field_id=field.field_id,