[BUG] Writes from empty partitions should return empty micropartitions with non-null schema (#2952)

If a partition is empty, the write still returns a list of file paths / partition columns, but their data type is Null. This is problematic because it causes a schema mismatch with partitions that did have writes, for example:

```
import daft

df = (
    daft.from_pydict({"foo": [1, 2, 3], "bar": ["a", "b", "c"]})
    .into_partitions(4)
    .write_parquet("z", partition_cols=["bar"])
)
print(df)

daft.exceptions.DaftCoreException: DaftError::SchemaMismatch MicroPartition concat requires all schemas to match, ╭─────────────┬──────╮
│ Column Name ┆ Type │
╞═════════════╪══════╡
│ path        ┆ Utf8 │
├╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌┤
│ bar         ┆ Utf8 │
╰─────────────┴──────╯
 vs ╭─────────────┬──────╮
│ Column Name ┆ Type │
╞═════════════╪══════╡
│ path        ┆ Null │
╰─────────────┴──────╯
```
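
With this change, an empty partition contributes an empty micropartition that carries the write's result schema (`path: Utf8`, `bar: Utf8` in this example) instead of a Null-typed one, so the concat succeeds. A rough sketch of the expected behavior after the fix (illustrative, not verbatim output):

```
import daft

df = (
    daft.from_pydict({"foo": [1, 2, 3], "bar": ["a", "b", "c"]})
    .into_partitions(4)
    .write_parquet("z", partition_cols=["bar"])
)
# No SchemaMismatch: the result now holds one row per written file
# (3 rows here, since only 3 of the 4 partitions had data).
print(df.to_pydict()["path"])  # 3 parquet file paths under z/
```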

---------

Co-authored-by: Colin Ho <[email protected]>
colin-ho and Colin Ho authored Sep 30, 2024
1 parent f1194b5 commit f10d4da
Showing 5 changed files with 98 additions and 8 deletions.
7 changes: 6 additions & 1 deletion daft/iceberg/iceberg_write.py
```
@@ -4,6 +4,8 @@
 from typing import TYPE_CHECKING, Any, Iterator, List, Tuple
 
 from daft import Expression, col
+from daft.datatype import DataType
+from daft.io.common import _get_schema_from_dict
 from daft.table import MicroPartition
 from daft.table.partitioning import PartitionedTable, partition_strings_to_path
 
@@ -211,7 +213,10 @@ def visitor(self, partition_record: "IcebergRecord") -> "IcebergWriteVisitors.FileVisitor":
         return self.FileVisitor(self, partition_record)
 
     def to_metadata(self) -> MicroPartition:
-        return MicroPartition.from_pydict({"data_file": self.data_files})
+        col_name = "data_file"
+        if len(self.data_files) == 0:
+            return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
+        return MicroPartition.from_pydict({col_name: self.data_files})
 
 
 def partitioned_table_to_iceberg_iter(
```
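
The key idea in `to_metadata`, shown as a minimal standalone sketch (it reuses the same internal helpers as the diff above; the `data_files` list here is hypothetical): when no files were written, return a zero-row MicroPartition whose `data_file` column is typed as a Python object column rather than Null, so it concatenates cleanly with metadata from partitions that did write files.

```
from daft.datatype import DataType
from daft.io.common import _get_schema_from_dict
from daft.table import MicroPartition

data_files: list = []  # hypothetical: this partition produced no Iceberg data files

if len(data_files) == 0:
    # Empty, but typed: the "data_file" column is DataType.python(), not Null.
    metadata = MicroPartition.empty(_get_schema_from_dict({"data_file": DataType.python()}))
else:
    metadata = MicroPartition.from_pydict({"data_file": data_files})
```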
22 changes: 15 additions & 7 deletions daft/table/table_io.py
```
@@ -22,13 +22,15 @@
     PythonStorageConfig,
     StorageConfig,
 )
+from daft.datatype import DataType
 from daft.dependencies import pa, pacsv, pads, pajson, pq
 from daft.expressions import ExpressionsProjection, col
 from daft.filesystem import (
     _resolve_paths_and_filesystem,
     canonicalize_protocol,
     get_protocol_from_path,
 )
+from daft.io.common import _get_schema_from_dict
 from daft.logical.schema import Schema
 from daft.runners.partitioning import (
     TableParseCSVOptions,
@@ -426,16 +428,22 @@ def __call__(self, written_file):
             self.parent.paths.append(written_file.path)
             self.parent.partition_indices.append(self.idx)
 
-    def __init__(self, partition_values: MicroPartition | None, path_key: str = "path"):
+    def __init__(self, partition_values: MicroPartition | None, schema: Schema):
         self.paths: list[str] = []
         self.partition_indices: list[int] = []
         self.partition_values = partition_values
-        self.path_key = path_key
+        self.path_key = schema.column_names()[
+            0
+        ]  # I kept this from our original code, but idk why it's the first column name -kevin
+        self.schema = schema
 
     def visitor(self, partition_idx: int) -> TabularWriteVisitors.FileVisitor:
         return self.FileVisitor(self, partition_idx)
 
     def to_metadata(self) -> MicroPartition:
+        if len(self.paths) == 0:
+            return MicroPartition.empty(self.schema)
+
         metadata: dict[str, Any] = {self.path_key: self.paths}
 
         if self.partition_values:
@@ -488,10 +496,7 @@ def write_tabular(
 
     partitioned = PartitionedTable(table, partition_cols)
 
-    # I kept this from our original code, but idk why it's the first column name -kevin
-    path_key = schema.column_names()[0]
-
-    visitors = TabularWriteVisitors(partitioned.partition_values(), path_key)
+    visitors = TabularWriteVisitors(partitioned.partition_values(), schema)
 
     for i, (part_table, part_path) in enumerate(partitioned_table_to_hive_iter(partitioned, resolved_path)):
         size_bytes = part_table.nbytes
@@ -686,7 +691,10 @@ def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisitors.FileVisitor:
         return self.FileVisitor(self, partition_values)
 
     def to_metadata(self) -> MicroPartition:
-        return MicroPartition.from_pydict({"add_action": self.add_actions})
+        col_name = "add_action"
+        if len(self.add_actions) == 0:
+            return MicroPartition.empty(_get_schema_from_dict({col_name: DataType.python()}))
+        return MicroPartition.from_pydict({col_name: self.add_actions})
 
 
 def write_deltalake(
```
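
For tabular (CSV/Parquet) writes, `TabularWriteVisitors` now receives the full result schema (the path column plus any partition columns) instead of just the first column name, so an empty write can return `MicroPartition.empty(self.schema)` with correctly typed columns. A small sketch of that pattern under assumed inputs (the schema construction here is illustrative, not the exact call sites in `write_tabular`):

```
from daft.datatype import DataType
from daft.io.common import _get_schema_from_dict
from daft.table import MicroPartition

# Illustrative result schema for a partitioned parquet write: the file path
# column plus the "bar" partition column, both Utf8.
schema = _get_schema_from_dict({"path": DataType.string(), "bar": DataType.string()})

paths: list[str] = []  # hypothetical: this partition wrote no files

if len(paths) == 0:
    # Empty metadata still carries the typed schema, so concat with other
    # partitions' metadata does not raise a SchemaMismatch.
    metadata = MicroPartition.empty(schema)
else:
    metadata = MicroPartition.from_pydict({"path": paths})
```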
40 changes: 40 additions & 0 deletions tests/cookbook/test_write.py
```
@@ -199,6 +199,26 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar
     assert readback["y"] == [y % 2 for y in data["x"]]
 
 
+def test_parquet_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path)
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict()
+    assert read_back == data
+
+
+def test_parquet_partitioned_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path, partition_cols=["x"])
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict()
+    assert read_back == data
+
+
 def test_csv_write(tmp_path):
     df = daft.read_csv(COOKBOOK_DATA_CSV)
 
@@ -262,3 +282,23 @@ def test_empty_csv_write_with_partitioning(tmp_path):
 
     assert len(pd_df) == 1
     assert len(pd_df._preview.preview_partition) == 1
+
+
+def test_csv_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path)
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict()
+    assert read_back == data
+
+
+def test_csv_partitioned_write_with_some_empty_partitions(tmp_path):
+    data = {"x": [1, 2, 3], "y": ["a", "b", "c"]}
+    output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path, partition_cols=["x"])
+
+    assert len(output_files) == 3
+
+    read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict()
+    assert read_back == data
```
25 changes: 25 additions & 0 deletions tests/io/delta_lake/test_table_write.py
```
@@ -185,6 +185,21 @@ def test_deltalake_write_ignore(tmp_path):
     assert read_delta.to_pyarrow_table() == df1.to_arrow()
 
 
+def test_deltalake_write_with_empty_partition(tmp_path, base_table):
+    deltalake = pytest.importorskip("deltalake")
+    path = tmp_path / "some_table"
+    df = daft.from_arrow(base_table).into_partitions(4)
+    result = df.write_deltalake(str(path))
+    result = result.to_pydict()
+    assert result["operation"] == ["ADD", "ADD", "ADD"]
+    assert result["rows"] == [1, 1, 1]
+
+    read_delta = deltalake.DeltaTable(str(path))
+    expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow())
+    assert df.schema() == expected_schema
+    assert read_delta.to_pyarrow_table() == base_table
+
+
 def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]):
     deltalake = pytest.importorskip("deltalake")
 
@@ -256,6 +271,16 @@ def test_deltalake_write_partitioned_empty(tmp_path):
     check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
 
 
+def test_deltalake_write_partitioned_some_empty(tmp_path):
+    path = tmp_path / "some_table"
+
+    df = daft.from_pydict({"int": [1, 2, 3, None], "string": ["foo", "foo", "bar", None]}).into_partitions(5)
+
+    df.write_deltalake(str(path), partition_cols=["int"])
+
+    check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")])
+
+
 def test_deltalake_write_partitioned_existing_table(tmp_path):
     path = tmp_path / "some_table"
 
```
12 changes: 12 additions & 0 deletions tests/io/iceberg/test_iceberg_writes.py
```
@@ -209,6 +209,18 @@ def test_read_after_write_nested_fields(local_catalog):
     assert as_arrow == read_back.to_arrow()
 
 
+def test_read_after_write_with_empty_partition(local_catalog):
+    df = daft.from_pydict({"x": [1, 2, 3]}).into_partitions(4)
+    as_arrow = df.to_arrow()
+    table = local_catalog.create_table("default.test", as_arrow.schema)
+    result = df.write_iceberg(table)
+    as_dict = result.to_pydict()
+    assert as_dict["operation"] == ["ADD", "ADD", "ADD"]
+    assert as_dict["rows"] == [1, 1, 1]
+    read_back = daft.read_iceberg(table)
+    assert as_arrow == read_back.to_arrow()
+
+
 @pytest.fixture
 def complex_table() -> tuple[pa.Table, Schema]:
     table = pa.table(
```
