fix avro writer writing null; add nullable int column partition test

jqin61 committed Feb 4, 2024
1 parent fcd94fd commit b74521e
Showing 3 changed files with 67 additions and 126 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -42,7 +42,7 @@ test-integration:
docker-compose -f dev/docker-compose-integration.yml up -d
sleep 5
docker-compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
poetry run pytest tests/ -v -m newyork ${PYTEST_ARGS} -s
poetry run pytest tests/ -v -m integration ${PYTEST_ARGS} -s

test-integration-rebuild:
docker-compose -f dev/docker-compose-integration.yml kill
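The Makefile hunk above swaps the ad-hoc "newyork" marker for the "integration" marker that the rewritten tests below carry. For reference, a hedged equivalent of that target from Python (assuming the marker is registered in the project's pytest configuration):

import pytest

# Select only tests decorated with @pytest.mark.integration, mirroring
# "poetry run pytest tests/ -v -m integration -s" from the Makefile target.
pytest.main(["tests/", "-v", "-m", "integration", "-s"])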
1 change: 1 addition & 0 deletions pyiceberg/manifest.py
@@ -308,6 +308,7 @@ def data_file_with_partition(partition_type: StructType, format_version: Literal
field_id=field.field_id,
name=field.name,
field_type=partition_field_to_data_file_partition_field(field.field_type),
required=False,
)
for field in partition_type.fields
])
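Why the one-line manifest.py change above fixes the Avro writer: when the partition source column is nullable, the partition value recorded for a data file can itself be null, and a required Avro field cannot encode a null. A minimal sketch of the idea (illustrative only; it reuses the pyiceberg type constructors that appear elsewhere in this diff, with the field id and name mirroring the test's partition spec):

from pyiceberg.types import IntegerType, NestedField, StructType

# Before: a required partition field cannot hold a null partition value.
required_partition = StructType(
    NestedField(field_id=1001, name="int", field_type=IntegerType(), required=True),
)

# After this commit: the data-file partition field is declared optional, so a
# manifest entry whose partition value is null can be serialized.
optional_partition = StructType(
    NestedField(field_id=1001, name="int", field_type=IntegerType(), required=False),
)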
190 changes: 65 additions & 125 deletions tests/integration/test_partitioned_writes.py
@@ -88,29 +88,6 @@ def catalog() -> Catalog:
],
}

TEST_DATA_WITHOUT_NULL = {
'bool': [False, True, True],
'string': ['a', 'z', 'z'],
# Go over the 16 bytes to kick in truncation
'string_long': ['a' * 22, 'a' * 22, 'z' * 22],
'int': [1, 1, 9],
'long': [1, 1, 9],
'float': [0.0, 0.0, 0.9],
'double': [0.0, 0.0, 0.9],
'timestamp': [datetime(2023, 1, 1, 19, 25, 00), datetime(2023, 1, 1, 19, 25, 00), datetime(2023, 3, 1, 19, 25, 00)],
'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), datetime(2023, 1, 1, 19, 25, 00), datetime(2023, 3, 1, 19, 25, 00)],
'date': [date(2023, 1, 1), date(2023, 1, 1), date(2023, 3, 1)],
# Not supported by Spark
# 'time': [time(1, 22, 0), None, time(19, 25, 0)],
# Not natively supported by Arrow
# 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes],
'binary': [b'\01', b'\01', b'\22'],
'fixed': [
uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
uuid.UUID('00000000-0000-0000-0000-000000000000').bytes,
uuid.UUID('11111111-1111-1111-1111-111111111111').bytes,
],
}

TABLE_SCHEMA = Schema(
NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False),
@@ -144,30 +121,6 @@ def session_catalog() -> Catalog:
)


@pytest.fixture(scope="session")
def arrow_table_without_null() -> pa.Table:
"""PyArrow table with all kinds of columns"""
pa_schema = pa.schema([
("bool", pa.bool_()),
("string", pa.string()),
("string_long", pa.string()),
("int", pa.int32()),
("long", pa.int64()),
("float", pa.float32()),
("double", pa.float64()),
("timestamp", pa.timestamp(unit="us")),
("timestamptz", pa.timestamp(unit="us", tz="UTC")),
("date", pa.date32()),
# Not supported by Spark
# ("time", pa.time64("us")),
# Not natively supported by Arrow
# ("uuid", pa.fixed(16)),
("binary", pa.binary()),
("fixed", pa.binary(16)),
])
return pa.Table.from_pydict(TEST_DATA_WITHOUT_NULL, schema=pa_schema)


@pytest.fixture(scope="session")
def arrow_table_with_null() -> pa.Table:
"""PyArrow table with all kinds of columns"""
@@ -191,10 +144,10 @@ def arrow_table_with_null() -> pa.Table:
])
return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)

# stub

@pytest.fixture(scope="session", autouse=True)
def table_v1_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_without_null_partitioned"
def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table, request) -> None:
identifier = "default.arrow_table_v1_with_null_partitioned"

try:
session_catalog.drop_table(identifier=identifier)
@@ -207,33 +160,14 @@ def table_v1_without_null_partitioned(session_catalog: Catalog, arrow_table_with
partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
properties={'format-version': '1'},
)
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"

# # for above
# @pytest.fixture(scope="session", autouse=True)
# def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
# identifier = "default.arrow_table_v1_without_null_partitioned"

# try:
# session_catalog.drop_table(identifier=identifier)
# except NoSuchTableError:
# pass

# tbl = session_catalog.create_table(
# identifier=identifier,
# schema=TABLE_SCHEMA,
# partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
# properties={'format-version': '1'},
# )
# tbl.append(arrow_table_with_null)

# assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"

@pytest.fixture(scope="session", autouse=True)
def table_v1_appended_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_appended_without_null_partitioned"
def table_v1_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_appended_with_null_partitioned"

try:
session_catalog.drop_table(identifier=identifier)
@@ -243,14 +177,14 @@ def table_v1_appended_without_null_partitioned(session_catalog: Catalog, arrow_t
tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '1'})

for _ in range(2):
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"


@pytest.fixture(scope="session", autouse=True)
def table_v2_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v2_without_null_partitioned"
def table_v2_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_table_v2_with_null_partitioned"

try:
session_catalog.drop_table(identifier=identifier)
@@ -263,14 +197,14 @@ def table_v2_without_null_partitioned(session_catalog: Catalog, arrow_table_with
partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
properties={'format-version': '2'},
)
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"


@pytest.fixture(scope="session", autouse=True)
def table_v2_appended_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v2_appended_without_null_partitioned"
def table_v2_appended_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_table_v2_appended_with_null_partitioned"

try:
session_catalog.drop_table(identifier=identifier)
@@ -280,32 +214,14 @@ def table_v2_appended_without_null_partitioned(session_catalog: Catalog, arrow_t
tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties={'format-version': '2'})

for _ in range(2):
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"


@pytest.mark.newyork
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None:
identifier = f"default.arrow_table_v{format_version}_without_null_partitioned"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 3, f"Expected 3 rows for {col}"


@pytest.mark.adrian
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_appended_null_partitioned(spark: SparkSession, col: str, format_version: int) -> None:
identifier = f"default.arrow_table_v{format_version}_appended_without_null_partitioned"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 6, f"Expected 6 rows for {col}"


@pytest.fixture(scope="session", autouse=True)
def table_v1_v2_appended_without_null(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_v2_appended_without_null"
def table_v1_v2_appended_with_null(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_v2_appended_with_null"

try:
session_catalog.drop_table(identifier=identifier)
@@ -318,17 +234,38 @@ def table_v1_v2_appended_without_null(session_catalog: Catalog, arrow_table_with
partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
properties={'format-version': '1'},
)
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"

with tbl.transaction() as tx:
tx.upgrade_table_version(format_version=2)

tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)

assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"

# TODO: parametrize the partition for each of the columns
@pytest.mark.integration
@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_null_partitioned(spark: SparkSession, col: str, format_version: int) -> None:
identifier = f"default.arrow_table_v{format_version}_with_null_partitioned"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 2, f"Expected 2 rows for {col}"
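# A hedged illustration of the expected counts (hypothetical values, not the actual
# fixture data): every column of TEST_DATA_WITH_NULL holds three rows with exactly
# one null, so a single append leaves 2 non-null rows per column and the
# double-append fixtures leave 4.
import pyarrow as pa

nullable_int = pa.Table.from_pydict(
    {"int": [1, None, 9]},  # one null out of three rows
    schema=pa.schema([("int", pa.int32())]),
)
assert nullable_int["int"].null_count == 1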


# TODO: parametrize the partition for each of the columns
@pytest.mark.integration
@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_appended_null_partitioned(spark: SparkSession, col: str, format_version: int) -> None:
identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 4, f"Expected 4 rows for {col}"




@pytest.fixture(scope="session")
def spark() -> SparkSession:
@@ -363,17 +300,18 @@ def spark() -> SparkSession:

return spark


@pytest.mark.adrian
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
# TODO: parametrize the partition for each of the columns
@pytest.mark.integration
@pytest.mark.parametrize("col", TEST_DATA_WITH_NULL.keys())
def test_query_filter_v1_v2_append_null(spark: SparkSession, col: str) -> None:
identifier = "default.arrow_table_v1_v2_appended_without_null"
identifier = "default.arrow_table_v1_v2_appended_with_null"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 6, f"Expected 3 row for {col}"
assert df.where(f"{col} is not null").count() == 4, f"Expected 4 rows for {col}"



@pytest.mark.adrian
def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
@pytest.mark.integration
def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_table_summaries"

try:
@@ -387,8 +325,8 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi
properties={'format-version': '2'},
)

tbl.append(arrow_table_without_null)
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)
tbl.append(arrow_table_with_null)

rows = spark.sql(
f"""
@@ -404,32 +342,33 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi
summaries = [row.summary for row in rows]

assert summaries[0] == {
'added-data-files': '2',
'added-files-size': '10433',
'added-data-files': '3',
'added-files-size': '14471',
'added-records': '3',
'total-data-files': '2',
'total-data-files': '3',
'total-delete-files': '0',
'total-equality-deletes': '0',
'total-files-size': '10433',
'total-files-size': '14471',
'total-position-deletes': '0',
'total-records': '3',
}

assert summaries[1] == {
'added-data-files': '2',
'added-files-size': '10433',
'added-data-files': '3',
'added-files-size': '14471',
'added-records': '3',
'total-data-files': '4',
'total-data-files': '6',
'total-delete-files': '0',
'total-equality-deletes': '0',
'total-files-size': '20866',
'total-files-size': '28942',
'total-position-deletes': '0',
'total-records': '6',
}
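# A hedged reading of the updated expectations above (assuming this summaries table
# is identity-partitioned on the nullable int column, which this hunk does not show):
# each 3-record append now spans three partition values, two non-null ints plus the
# null partition, so every append adds 3 data files instead of the previous 2, and
# the totals double after the second append.
assert 2 * 14471 == 28942  # total-files-size after two appends of 14471 bytes each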


@pytest.mark.adrian
def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
# TODO: parametrize the partition for each of the columns
@pytest.mark.integration
def test_data_files_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"

try:
@@ -443,8 +382,8 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
properties={'format-version': '1'},
)

tbl.append(arrow_table_without_null)
tbl.append(arrow_table_without_null)
tbl.append(arrow_table_with_null)
tbl.append(arrow_table_with_null)

# added_data_files_count, existing_data_files_count, deleted_data_files_count
rows = spark.sql(
@@ -454,7 +393,7 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
"""
).collect()

assert [row.added_data_files_count for row in rows] == [2, 2, 2]
assert [row.added_data_files_count for row in rows] == [3, 3, 3]
assert [row.existing_data_files_count for row in rows] == [
0,
0,
@@ -463,8 +402,9 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]


@pytest.mark.adrian
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
# This test probably does not need to be duplicated for the with-null data.
@pytest.mark.integration
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.arrow_data_files"

try:
