
Handle Empty RecordBatch within _task_to_record_batches (#1026)
sungwy authored Aug 9, 2024
1 parent 3a06237 commit 159805d
Showing 3 changed files with 101 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pyiceberg/io/pyarrow.py
@@ -1235,9 +1235,11 @@ def _task_to_record_batches(
         columns=[col.name for col in file_project_schema.columns],
     )
 
-    current_index = 0
+    next_index = 0
     batches = fragment_scanner.to_batches()
     for batch in batches:
+        next_index = next_index + len(batch)
+        current_index = next_index - len(batch)
         if positional_deletes:
             # Create the mask of indices that we're interested in
             indices = _combine_positional_deletes(positional_deletes, current_index, current_index + len(batch))
@@ -1249,11 +1251,12 @@
                 # https://github.com/apache/arrow/issues/39220
                 arrow_table = pa.Table.from_batches([batch])
                 arrow_table = arrow_table.filter(pyarrow_filter)
+                if len(arrow_table) == 0:
+                    continue
                 batch = arrow_table.to_batches()[0]
         yield _to_requested_schema(
             projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True, use_large_types=use_large_types
         )
-        current_index += len(batch)


def _task_to_table(
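The two hunks above work together: positional-delete offsets are file-level row positions, so each batch's offset window must be derived from its unfiltered length, and a batch that the row filter empties entirely must be skipped, because calling to_batches()[0] on a zero-row table can raise IndexError. Below is a minimal, self-contained sketch of that pattern; the data and the filter are hypothetical and do not appear in the commit.

import pyarrow as pa
import pyarrow.compute as pc

# Hypothetical data: six rows split into 2-row batches, standing in for the
# batches yielded by the fragment scanner.
table = pa.table({"number": list(range(6))})
batches = table.to_batches(max_chunksize=2)

next_index = 0
for batch in batches:
    next_index += len(batch)                 # advance by the *unfiltered* length
    current_index = next_index - len(batch)  # start offset of this batch in the file
    # Positional deletes for this batch would target file rows in
    # [current_index, current_index + len(batch)).

    # RecordBatch does not support expression filters directly
    # (https://github.com/apache/arrow/issues/39220), hence the Table round-trip.
    arrow_table = pa.Table.from_batches([batch])
    arrow_table = arrow_table.filter(pc.greater(arrow_table["number"], 2))
    if len(arrow_table) == 0:
        continue  # skip fully filtered batches; to_batches() can return [] here
    batch = arrow_table.to_batches()[0]
    print(f"file rows [{current_index}, {next_index}) -> {batch['number'].to_pylist()}")

The first batch ([0, 1]) is filtered away entirely and skipped, while the offsets for the remaining batches stay anchored to their positions in the data file even though filtering shrinks them.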
68 changes: 68 additions & 0 deletions tests/integration/test_deletes.py
@@ -224,6 +224,74 @@ def test_delete_partitioned_table_positional_deletes(spark: SparkSession, sessio
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [10], "number": [20]}


@pytest.mark.integration
@pytest.mark.filterwarnings("ignore:Merge on read is not yet supported, falling back to copy-on-write")
def test_delete_partitioned_table_positional_deletes_empty_batch(spark: SparkSession, session_catalog: RestCatalog) -> None:
identifier = "default.test_delete_partitioned_table_positional_deletes_empty_batch"

run_spark_commands(
spark,
[
f"DROP TABLE IF EXISTS {identifier}",
f"""
CREATE TABLE {identifier} (
number_partitioned int,
number int
)
USING iceberg
PARTITIONED BY (number_partitioned)
TBLPROPERTIES(
'format-version' = 2,
'write.delete.mode'='merge-on-read',
'write.update.mode'='merge-on-read',
'write.merge.mode'='merge-on-read',
'write.parquet.row-group-limit'=1
)
""",
],
)

tbl = session_catalog.load_table(identifier)

arrow_table = pa.Table.from_arrays(
[
pa.array([10, 10, 10]),
pa.array([1, 2, 3]),
],
schema=pa.schema([pa.field("number_partitioned", pa.int32()), pa.field("number", pa.int32())]),
)

tbl.append(arrow_table)

assert len(tbl.scan().to_arrow()) == 3

run_spark_commands(
spark,
[
# Generate a positional delete
f"""
DELETE FROM {identifier} WHERE number = 1
""",
],
)
    # Assert that there is just a single data file, with one associated merge-on-read delete file
tbl = tbl.refresh()

files = list(tbl.scan().plan_files())
assert len(files) == 1
assert len(files[0].delete_files) == 1

assert len(tbl.scan().to_arrow()) == 2

assert len(tbl.scan(row_filter="number_partitioned == 10").to_arrow()) == 2

assert len(tbl.scan(row_filter="number_partitioned == 1").to_arrow()) == 0

reader = tbl.scan(row_filter="number_partitioned == 1").to_arrow_batch_reader()
assert isinstance(reader, pa.RecordBatchReader)
assert len(reader.read_all()) == 0


@pytest.mark.integration
@pytest.mark.filterwarnings("ignore:Merge on read is not yet supported, falling back to copy-on-write")
def test_overwrite_partitioned_table(spark: SparkSession, session_catalog: RestCatalog) -> None:
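The 'write.parquet.row-group-limit'=1 property in test_delete_partitioned_table_positional_deletes_empty_batch above is what makes the regression reproducible: each Parquet row group becomes its own one-row RecordBatch when scanned, so a single positional delete can empty an entire batch. A small sketch of that mechanism, with made-up values:

import pyarrow as pa

# One-row batch, standing in for a one-row Parquet row group (hypothetical value).
batch = pa.RecordBatch.from_pydict({"number": [1]})

# A positional delete removes row 0, so no indices survive the take().
surviving = pa.array([], type=pa.int64())
empty_batch = batch.take(surviving)

# Before this commit, an all-deleted batch like this one reached the pyarrow
# filter round-trip and crashed; now it is skipped.
assert len(empty_batch) == 0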
28 changes: 28 additions & 0 deletions tests/integration/test_reads.py
@@ -791,3 +791,31 @@ def test_empty_scan_ordered_str(catalog: Catalog) -> None:
table_empty_scan_ordered_str = catalog.load_table("default.test_empty_scan_ordered_str")
arrow_table = table_empty_scan_ordered_str.scan(EqualTo("id", "b")).to_arrow()
assert len(arrow_table) == 0


@pytest.mark.integration
@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
def test_table_scan_empty_table(catalog: Catalog) -> None:
identifier = "default.test_table_scan_empty_table"
arrow_table = pa.Table.from_arrays(
[
pa.array([]),
],
schema=pa.schema([pa.field("colA", pa.string())]),
)

try:
catalog.drop_table(identifier)
except NoSuchTableError:
pass

tbl = catalog.create_table(
identifier,
schema=arrow_table.schema,
)

tbl.append(arrow_table)

result_table = tbl.scan().to_arrow()

assert len(result_table) == 0

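As a companion to the batch-reader assertion in test_deletes.py, a streaming read of this empty table should also succeed. A sketch reusing the tbl variable from the test above; this check is not part of the commit:

import pyarrow as pa

reader = tbl.scan().to_arrow_batch_reader()
assert isinstance(reader, pa.RecordBatchReader)
assert len(reader.read_all()) == 0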