Changes from 2 commits
2 changes: 1 addition & 1 deletion dev/spark/Dockerfile
@@ -18,7 +18,7 @@ ARG BASE_IMAGE_SPARK_VERSION=4.0.1
FROM apache/spark:${BASE_IMAGE_SPARK_VERSION}

# Dependency versions - keep these compatible
-ARG ICEBERG_VERSION=1.10.0
+ARG ICEBERG_VERSION=1.10.1
ARG ICEBERG_SPARK_RUNTIME_VERSION=4.0_2.13
ARG SPARK_VERSION=4.0.1
ARG HADOOP_VERSION=3.4.1
6 changes: 5 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -789,7 +789,11 @@ def visit_string(self, _: StringType) -> pa.DataType:
return pa.large_string()

def visit_uuid(self, _: UUIDType) -> pa.DataType:
-return pa.uuid()
+# TODO: Change to uuid when PyArrow implements filtering for UUID types
+# Using binary(16) instead of pa.uuid() because filtering is not
+# implemented for UUID types in PyArrow
+# (context: https://github.com/apache/iceberg-python/issues/2372)
+return pa.binary(16)

def visit_unknown(self, _: UnknownType) -> pa.DataType:
"""Type `UnknownType` can be promoted to any primitive type in V3+ tables per the Iceberg spec."""
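As an aside on the mapping this change adopts: pa.binary(16) is PyArrow's fixed-size binary type of width 16, so UUID values travel as their raw 16-byte representation. A minimal sketch (not part of the diff), reusing a UUID from the tests below, of how values round-trip through that representation:

import uuid
import pyarrow as pa

# Iceberg UUID columns are materialized as 16-byte fixed-size binary on the Arrow side.
u = uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")
arr = pa.array([u.bytes, None], type=pa.binary(16))

assert arr.type == pa.binary(16)  # fixed_size_binary[16]
assert uuid.UUID(bytes=arr[0].as_py()) == u  # convert back when uuid.UUID objects are needed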
2 changes: 1 addition & 1 deletion tests/integration/test_add_files.py
@@ -828,7 +828,7 @@ def test_add_files_with_valid_upcast(
pa.field("list", pa.list_(pa.int64()), nullable=False),
pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
pa.field("double", pa.float64(), nullable=True),
pa.field("uuid", pa.uuid(), nullable=True),
pa.field("uuid", pa.binary(16), nullable=True),
)
)
)
8 changes: 4 additions & 4 deletions tests/integration/test_reads.py
@@ -610,15 +610,15 @@ def test_partitioned_tables(catalog: Catalog) -> None:
def test_unpartitioned_uuid_table(catalog: Catalog) -> None:
unpartitioned_uuid = catalog.load_table("default.test_uuid_and_fixed_unpartitioned")
arrow_table_eq = unpartitioned_uuid.scan(row_filter="uuid_col == '102cb62f-e6f8-4eb0-9973-d9b012ff0967'").to_arrow()
-assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")]
+assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes]

arrow_table_neq = unpartitioned_uuid.scan(
row_filter="uuid_col != '102cb62f-e6f8-4eb0-9973-d9b012ff0967' and uuid_col != '639cccce-c9d2-494a-a78c-278ab234f024'"
).to_arrow()
assert arrow_table_neq["uuid_col"].to_pylist() == [
uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226"),
uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b"),
uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e"),
uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226").bytes,
uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b").bytes,
uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e").bytes,
]


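The updated assertions above reflect the behavior change for readers: to_pylist() on a uuid column now yields raw 16-byte values rather than uuid.UUID objects. A hypothetical helper (not part of this PR) for callers that still want uuid.UUID values back:

import uuid
import pyarrow as pa

def uuid_column_to_pylist(column: pa.ChunkedArray) -> list:
    """Convert a binary(16) uuid column from a scan result back to uuid.UUID, preserving nulls."""
    return [uuid.UUID(bytes=value) if value is not None else None for value in column.to_pylist()]

# e.g., against the table used in the test above:
# uuid_column_to_pylist(unpartitioned_uuid.scan().to_arrow()["uuid_col"])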
64 changes: 62 additions & 2 deletions tests/integration/test_writes/test_writes.py
@@ -1500,7 +1500,7 @@ def test_table_write_schema_with_valid_upcast(
pa.field("list", pa.list_(pa.int64()), nullable=False),
pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double
pa.field("uuid", pa.uuid(), nullable=True),
pa.field("uuid", pa.binary(16), nullable=True),
)
)
)
@@ -2138,7 +2138,7 @@ def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transf
tbl.append(arr_table)

lhs = [r[0] for r in spark.table(identifier).collect()]
-rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
+rhs = [str(uuid.UUID(bytes=u.as_py())) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
assert lhs == rhs


@@ -2530,3 +2530,63 @@ def test_v3_write_and_read_row_lineage(spark: SparkSession, session_catalog: Cat
assert tbl.metadata.next_row_id == initial_next_row_id + len(test_data), (
"Expected next_row_id to be incremented by the number of added rows"
)


@pytest.mark.integration
def test_write_uuid_in_pyiceberg_and_scan(session_catalog: Catalog, spark: SparkSession) -> None:
"""Test UUID compatibility between PyIceberg and Spark.

UUIDs must be written as binary(16) for Spark compatibility since Java Arrow
metadata differs from Python Arrow metadata for UUID types.
"""
identifier = "default.test_write_uuid_in_pyiceberg_and_scan"

catalog = load_catalog("default", type="in-memory")
[Review thread on this line]
Contributor: nit: do we need these lines with the session catalog?
Collaborator (Author): No, I fixed it in 1751f0c

catalog.create_namespace("ns")

schema = Schema(NestedField(field_id=1, name="uuid_col", field_type=UUIDType(), required=False))

test_data_with_null = {
"uuid_col": [
uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
None,
uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
]
}

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

table = _create_table(session_catalog, identifier, {"format-version": "2"}, schema=schema)

arrow_table = pa.table(test_data_with_null, schema=schema.as_arrow())

# Write with pyarrow
table.append(arrow_table)

# Write with pyspark
spark.sql(
f"""
INSERT INTO {identifier} VALUES ("22222222-2222-2222-2222-222222222222")
"""
)
df = spark.table(identifier)

table.refresh()

assert df.count() == 4
assert len(table.scan().to_arrow()) == 4

result = df.where("uuid_col = '00000000-0000-0000-0000-000000000000'")
assert result.count() == 1

result = df.where("uuid_col = '22222222-2222-2222-2222-222222222222'")
assert result.count() == 1

result = table.scan(row_filter=EqualTo("uuid_col", uuid.UUID("00000000-0000-0000-0000-000000000000").bytes)).to_arrow()
assert len(result) == 1

result = table.scan(row_filter=EqualTo("uuid_col", uuid.UUID("22222222-2222-2222-2222-222222222222").bytes)).to_arrow()
assert len(result) == 1
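
A usage note following the docstring's point about Spark compatibility, and the same normalization pattern as the test_uuid_partitioning change above: PyIceberg scans return 16-byte values, while Spark (via the Iceberg Spark runtime) returns the canonical string form. A sketch only, assuming the table, spark session, and identifier from the test above:

import uuid

# Normalize both sides to uuid.UUID before comparing across engines.
py_values = sorted(
    uuid.UUID(bytes=v) for v in table.scan().to_arrow()["uuid_col"].to_pylist() if v is not None
)
spark_values = sorted(
    uuid.UUID(r[0]) for r in spark.table(identifier).collect() if r[0] is not None
)
assert py_values == spark_values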