Check if schema is compatible in add_files API (apache#907)

Co-authored-by: Fokko Driesprong <[email protected]>
sungwy and Fokko authored Jul 12, 2024
1 parent aceed2a commit dceedfa

Showing 5 changed files with 211 additions and 164 deletions.
45 changes: 45 additions & 0 deletions pyiceberg/io/pyarrow.py
@@ -2032,6 +2032,49 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
    return bin_packed_record_batches


def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema, downcast_ns_timestamp_to_us: bool = False) -> None:
    """
    Check if the `table_schema` is compatible with `other_schema`.

    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

    Raises:
        ValueError: If the schemas are not compatible.
    """
    name_mapping = table_schema.name_mapping
    try:
        task_schema = pyarrow_to_schema(
            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )
    except ValueError as e:
        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
        ) from e

    if table_schema.as_struct() != task_schema.as_struct():
        from rich.console import Console
        from rich.table import Table as RichTable

        console = Console(record=True)

        rich_table = RichTable(show_header=True, header_style="bold")
        rich_table.add_column("")
        rich_table.add_column("Table field")
        rich_table.add_column("Dataframe field")

        for lhs in table_schema.fields:
            try:
                rhs = task_schema.find_field(lhs.field_id)
                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
            except ValueError:
                rich_table.add_row("❌", str(lhs), "Missing")

        console.print(rich_table)
        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
    for file_path in file_paths:
        input_file = io.new_input(file_path)
@@ -2043,6 +2086,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
_check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())

statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
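For context, a hedged sketch of the user-facing effect of this hook: add_files now validates each Parquet file's schema before committing. The catalog name, table identifier, and file path below are assumptions for illustration.

# Hypothetical sketch; catalog name, table identifier, and path are invented.
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
tbl = catalog.load_table("default.events")

try:
    tbl.add_files(file_paths=["s3://warehouse/default/events/data.parquet"])
except ValueError as exc:
    # With this commit, an incompatible Parquet schema fails fast here,
    # carrying a field-by-field diff in the message.
    print(exc)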
62 changes: 10 additions & 52 deletions pyiceberg/table/__init__.py
@@ -73,7 +73,7 @@
    manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
    POSITIONAL_DELETE_SCHEMA,
    DataFile,
@@ -166,54 +166,8 @@

ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
_JAVA_LONG_MAX = 9223372036854775807


def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
    """
    Check if the `table_schema` is compatible with `other_schema`.

    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

    Raises:
        ValueError: If the schemas are not compatible.
    """
    from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

    downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
    name_mapping = table_schema.name_mapping
    try:
        task_schema = pyarrow_to_schema(
            other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )
    except ValueError as e:
        other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
        additional_names = set(other_schema.column_names) - set(table_schema.column_names)
        raise ValueError(
            f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
        ) from e

    if table_schema.as_struct() != task_schema.as_struct():
        from rich.console import Console
        from rich.table import Table as RichTable

        console = Console(record=True)

        rich_table = RichTable(show_header=True, header_style="bold")
        rich_table.add_column("")
        rich_table.add_column("Table field")
        rich_table.add_column("Dataframe field")

        for lhs in table_schema.fields:
            try:
                rhs = task_schema.find_field(lhs.field_id)
                rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
            except ValueError:
                rich_table.add_row("❌", str(lhs), "Missing")

        console.print(rich_table)
        raise ValueError(f"Mismatch in fields:\n{console.export_text()}")
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"


class TableProperties:
@@ -526,8 +480,10 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
            raise ValueError(
                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
            )

        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_schema_compatible(
            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )
        # cast if the two schemas are compatible but not equal
        table_arrow_schema = self._table.schema().as_arrow()
        if table_arrow_schema != df.schema:
@@ -585,8 +541,10 @@ def overwrite(
            raise ValueError(
                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
            )

        _check_schema_compatible(self._table.schema(), other_schema=df.schema)
        downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
        _check_schema_compatible(
            self._table.schema(), other_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
        )
        # cast if the two schemas are compatible but not equal
        table_arrow_schema = self._table.schema().as_arrow()
        if table_arrow_schema != df.schema:
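For context, a sketch of where the downcast flag comes from: append and overwrite now read it via Config() at call time instead of inside the schema check. This assumes PyIceberg's usual configuration mechanism (a .pyiceberg.yaml entry or a PYICEBERG_-prefixed environment variable); the exact env-var mapping is an assumption here.

# Hypothetical sketch: opting into the ns -> us timestamp downcast on write.
# Assumes the PYICEBERG_ env-var mapping (dashes become underscores);
# a .pyiceberg.yaml entry should work as well.
import os

os.environ["PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE"] = "true"

from pyiceberg.utils.config import Config

# Mirrors the lookup performed in append()/overwrite() above.
downcast = Config().get_bool("downcast-ns-timestamp-to-us-on-write") or False
print(downcast)  # True when the property is set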
85 changes: 65 additions & 20 deletions tests/integration/test_add_files.py
@@ -17,6 +17,7 @@
# pylint:disable=redefined-outer-name

import os
import re
from datetime import date
from typing import Iterator

@@ -463,6 +464,57 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
    assert summary["snapshot_prop_a"] == "test_prop_a"


@pytest.mark.integration
def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
    identifier = f"default.table_schema_mismatch_fails_v{format_version}"

    tbl = _create_table(session_catalog, identifier, format_version)
    WRONG_SCHEMA = pa.schema([
        ("foo", pa.bool_()),
        ("bar", pa.string()),
        ("baz", pa.string()),  # should be integer
        ("qux", pa.date32()),
    ])
    file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet"
    # write parquet files
    fo = tbl.io.new_output(file_path)
    with fo.create(overwrite=True) as fos:
        with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer:
            writer.write_table(
                pa.Table.from_pylist(
                    [
                        {
                            "foo": True,
                            "bar": "bar_string",
                            "baz": "123",
                            "qux": date(2024, 3, 7),
                        },
                        {
                            "foo": True,
                            "bar": "bar_string",
                            "baz": "124",
                            "qux": date(2024, 3, 7),
                        },
                    ],
                    schema=WRONG_SCHEMA,
                )
            )

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
│ ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
│ ❌ │ 3: baz: optional int     │ 3: baz: optional string  │
│ ✅ │ 4: qux: optional date    │ 4: qux: optional date    │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        tbl.add_files(file_paths=[file_path])


@pytest.mark.integration
def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
    identifier = f"default.unpartitioned_with_large_types{format_version}"
@@ -518,7 +570,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca
    assert table_schema == arrow_schema_large


def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
def test_add_files_with_timestamp_tz_ns_fails(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
    nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

    nanoseconds_schema = pa.schema([
@@ -549,25 +601,18 @@ def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_versi
        partition_spec=PartitionSpec(),
    )

    file_paths = [f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test-{i}.parquet" for i in range(5)]
    file_path = f"s3://warehouse/default/test_timestamp_tz/v{format_version}/test.parquet"
    # write parquet files
    for file_path in file_paths:
        fo = tbl.io.new_output(file_path)
        with fo.create(overwrite=True) as fos:
            with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
                writer.write_table(arrow_table)
    fo = tbl.io.new_output(file_path)
    with fo.create(overwrite=True) as fos:
        with pq.ParquetWriter(fos, schema=nanoseconds_schema) as writer:
            writer.write_table(arrow_table)

    # add the parquet files as data files
    tbl.add_files(file_paths=file_paths)

    assert tbl.scan().to_arrow() == pa.concat_tables(
        [
            arrow_table.cast(
                pa.schema([
                    ("quux", pa.timestamp("us", tz="UTC")),
                ]),
                safe=False,
            )
        ]
        * 5
    )
    with pytest.raises(
        TypeError,
        match=re.escape(
            "Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
        ),
    ):
        tbl.add_files(file_paths=[file_path])
91 changes: 91 additions & 0 deletions tests/io/test_pyarrow.py
@@ -60,6 +60,7 @@
    PyArrowFile,
    PyArrowFileIO,
    StatsAggregator,
    _check_schema_compatible,
    _ConvertToArrowSchema,
    _determine_partitions,
    _primitive_to_physical,
@@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
    assert len(list(bin_packed)) == 5


def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.decimal128(18, 6), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field                 ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ Missing                  │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
        pa.field("new_field", pa.date32(), nullable=True),
    ))

    expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)


def test_schema_downcast(table_schema_simple: Schema) -> None:
    # large_string type is compatible with string type
    other_schema = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    try:
        _check_schema_compatible(table_schema_simple, other_schema)
    except Exception:
        pytest.fail("Unexpected Exception raised when calling `_check_schema`")


def test_partition_for_demo() -> None:
    test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    test_schema = Schema(