[FEAT] Delta Lake partitioned writing (#2884)
Some things that I will cover in follow-up PRs:
- split `table_io.py` up into multiple files
- fix partitioned writes to conform to Hive style (binary encoding, string escaping, etc.)

The partitioned-write fix should not actually be blocking: partition values in the Delta log do not need to be Hive-encoded (Spark does not encode them either), only stringified. Just don't read the output as a Hive table.
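
For illustration, here is a rough sketch of that distinction (hypothetical column name and value; percent-encoding stands in for Hive's exact escaping rules):

```python
from urllib.parse import quote

# Hypothetical partition value containing characters a Hive-style path must escape.
value = "2024/09/24 12:00"

# Hive-style directory component: the value is escaped before being put in the path.
hive_dir = f"ts={quote(value, safe='')}"  # "ts=2024%2F09%2F24%2012%3A00"

# Delta log `partitionValues` entry: the value is only stringified, never escaped.
delta_log_entry = {"ts": value}  # {"ts": "2024/09/24 12:00"}
```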
kevinzwang authored Sep 24, 2024
1 parent 02b30be commit b519944
Showing 18 changed files with 340 additions and 190 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions Cargo.toml
@@ -149,6 +149,7 @@ bytes = "1.6.0"
chrono = "0.4.38"
chrono-tz = "0.8.4"
comfy-table = "7.1.1"
derivative = "2.2.0"
dyn-clone = "1"
futures = "0.3.30"
html-escape = "0.2.13"
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
@@ -1720,6 +1720,7 @@ class LogicalPlanBuilder:
mode: str,
version: int,
large_dtypes: bool,
partition_cols: list[str] | None = None,
io_config: IOConfig | None = None,
) -> LogicalPlanBuilder: ...
def lance_write(
43 changes: 27 additions & 16 deletions daft/dataframe/dataframe.py
@@ -766,6 +766,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
def write_deltalake(
self,
table: Union[str, pathlib.Path, "DataCatalogTable", "deltalake.DeltaTable"],
partition_cols: Optional[List[str]] = None,
mode: Literal["append", "overwrite", "error", "ignore"] = "append",
schema_mode: Optional[Literal["merge", "overwrite"]] = None,
name: Optional[str] = None,
@@ -783,6 +784,7 @@ def write_deltalake(
Args:
table (Union[str, pathlib.Path, DataCatalogTable, deltalake.DeltaTable]): Destination `Delta Lake Table <https://delta-io.github.io/delta-rs/api/delta_table/>`__ or table URI to write dataframe to.
partition_cols (List[str], optional): Columns to partition the table by. If the table already exists, these are expected to match its existing partition columns; otherwise, the table is created with the specified partition columns. Defaults to None.
mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data, `error` will raise an error if table already exists, and `ignore` will not write anything if table already exists. Defaults to "append".
schema_mode (str, optional): Schema mode of the write. If set to `overwrite`, allows replacing the schema of the table when doing `mode=overwrite`. Schema mode `merge` is currently not supported.
name (str, optional): User-provided identifier for this table.
@@ -802,10 +804,7 @@ def write_deltalake(
import deltalake
import pyarrow as pa
from deltalake.schema import _convert_pa_schema_to_delta
from deltalake.writer import (
try_get_deltatable,
write_deltalake_pyarrow,
)
from deltalake.writer import AddAction, try_get_deltatable, write_deltalake_pyarrow
from packaging.version import parse

from daft import from_pydict
@@ -861,6 +860,13 @@ def write_deltalake(
delta_schema = _convert_pa_schema_to_delta(pyarrow_schema, **large_dtypes_kwargs(large_dtypes))

if table:
if partition_cols and partition_cols != table.metadata().partition_columns:
raise ValueError(
f"Expected partition columns to match that of the existing table ({table.metadata().partition_columns}), but received: {partition_cols}"
)
else:
partition_cols = table.metadata().partition_columns

table.update_incremental()

table_schema = table.schema().to_pyarrow(as_large_types=large_dtypes)
@@ -884,42 +890,45 @@ def write_deltalake(
else:
version = 0

if partition_cols is not None:
for c in partition_cols:
if self.schema()[c].dtype == DataType.binary():
raise NotImplementedError("Binary partition columns are not yet supported for Delta Lake writes")

builder = self._builder.write_deltalake(
table_uri,
mode,
version,
large_dtypes,
io_config=io_config,
partition_cols=partition_cols,
)
write_df = DataFrame(builder)
write_df.collect()

write_result = write_df.to_pydict()
assert "data_file" in write_result
data_files = write_result["data_file"]
add_action = []
assert "add_action" in write_result
add_actions: List[AddAction] = write_result["add_action"]

operations = []
paths = []
rows = []
sizes = []

for data_file in data_files:
stats = json.loads(data_file.stats)
for add_action in add_actions:
stats = json.loads(add_action.stats)
operations.append("ADD")
paths.append(data_file.path)
paths.append(add_action.path)
rows.append(stats["numRecords"])
sizes.append(data_file.size)

add_action.append(data_file)
sizes.append(add_action.size)

if table is None:
write_deltalake_pyarrow(
table_uri,
delta_schema,
add_action,
add_actions,
mode,
[],
partition_cols or [],
name,
description,
configuration,
@@ -936,7 +945,9 @@ def write_deltalake(
rows.append(old_actions_dict["num_records"][i])
sizes.append(old_actions_dict["size_bytes"][i])

table._table.create_write_transaction(add_action, mode, [], delta_schema, None, custom_metadata)
table._table.create_write_transaction(
add_actions, mode, partition_cols or [], delta_schema, None, custom_metadata
)
table.update_incremental()

with_operations = from_pydict(
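A minimal usage sketch of the new `partition_cols` argument (toy data and a local path; assumes the `deltalake` package is installed):

```python
import daft

df = daft.from_pydict({"country": ["US", "US", "CA"], "value": [1, 2, 3]})

# Create a new Delta Lake table partitioned by `country`.
df.write_deltalake("/tmp/partitioned_table", partition_cols=["country"])

# Appends must use the table's existing partition columns; a different list
# raises the ValueError added above.
daft.from_pydict({"country": ["MX"], "value": [4]}).write_deltalake(
    "/tmp/partitioned_table", partition_cols=["country"], mode="append"
)
```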
2 changes: 2 additions & 0 deletions daft/execution/execution_step.py
@@ -429,6 +429,7 @@ class WriteDeltaLake(SingleOutputInstruction):
base_path: str
large_dtypes: bool
version: int
partition_cols: list[str] | None
io_config: IOConfig | None

def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]:
@@ -456,6 +457,7 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition:
large_dtypes=self.large_dtypes,
base_path=self.base_path,
version=self.version,
partition_cols=self.partition_cols,
io_config=self.io_config,
)

2 changes: 2 additions & 0 deletions daft/execution/physical_plan.py
@@ -147,6 +147,7 @@ def deltalake_write(
base_path: str,
large_dtypes: bool,
version: int,
partition_cols: list[str] | None,
io_config: IOConfig | None,
) -> InProgressPhysicalPlan[PartitionT]:
"""Write the results of `child_plan` into pyiceberg data files described by `write_info`."""
@@ -157,6 +158,7 @@
base_path=base_path,
large_dtypes=large_dtypes,
version=version,
partition_cols=partition_cols,
io_config=io_config,
),
)
2 changes: 2 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
@@ -363,13 +363,15 @@ def write_deltalake(
path: str,
large_dtypes: bool,
version: int,
partition_cols: list[str] | None,
io_config: IOConfig | None,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
return physical_plan.deltalake_write(
input,
path,
large_dtypes,
version,
partition_cols,
io_config,
)

12 changes: 7 additions & 5 deletions daft/iceberg/iceberg_write.py
@@ -5,7 +5,7 @@

from daft import Expression, col
from daft.table import MicroPartition
from daft.table.partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string
from daft.table.partitioning import PartitionedTable, partition_strings_to_path

if TYPE_CHECKING:
import pyarrow as pa
@@ -222,13 +222,15 @@ def partitioned_table_to_iceberg_iter(
partition_values = partitioned.partition_values()

if partition_values:
partition_strings = partition_values_to_string(partition_values, partition_null_fallback="null").to_pylist()
partition_values_list = partition_values.to_pylist()
partition_strings = partitioned.partition_values_str()
assert partition_strings is not None

for table, part_vals, part_strs in zip(partitioned.partitions(), partition_values_list, partition_strings):
for table, part_vals, part_strs in zip(
partitioned.partitions(), partition_values.to_pylist(), partition_strings.to_pylist()
):
iceberg_part_vals = {k: to_partition_representation(v) for k, v in part_vals.items()}
part_record = IcebergRecord(**iceberg_part_vals)
part_path = partition_strings_to_path(root_path, part_strs)
part_path = partition_strings_to_path(root_path, part_strs, partition_null_fallback="null")

arrow_table = coerce_pyarrow_table_to_schema(table.to_arrow(), schema)

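A small sketch of the updated helper as called here — the Iceberg writer passes `partition_null_fallback="null"`, while the default fallback (used for Hive-style paths) is `__HIVE_DEFAULT_PARTITION__`; the root path and values below are made up:

```python
from daft.table.partitioning import partition_strings_to_path

# Hypothetical partition strings for one partition, including a null value.
part_strs = {"year": "2024", "category": None}

# Iceberg write path: nulls become the literal string "null".
partition_strings_to_path("s3://bucket/table/data", part_strs, partition_null_fallback="null")
# -> "s3://bucket/table/data/year=2024/category=null"

# Default fallback.
partition_strings_to_path("s3://bucket/table/data", part_strs)
# -> "s3://bucket/table/data/year=2024/category=__HIVE_DEFAULT_PARTITION__"
```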
2 changes: 2 additions & 0 deletions daft/logical/builder.py
@@ -304,6 +304,7 @@ def write_deltalake(
version: int,
large_dtypes: bool,
io_config: IOConfig,
partition_cols: list[str] | None = None,
) -> LogicalPlanBuilder:
columns_name = self.schema().column_names()
builder = self._builder.delta_write(
@@ -312,6 +313,7 @@
mode,
version,
large_dtypes,
partition_cols,
io_config,
)
return LogicalPlanBuilder(builder)
52 changes: 31 additions & 21 deletions daft/table/partitioning.py
@@ -1,34 +1,20 @@
from typing import Dict, List, Optional

from daft import Series
from daft.expressions import ExpressionsProjection
from daft.series import Series

from .micropartition import MicroPartition


def partition_strings_to_path(root_path: str, parts: Dict[str, str]):
postfix = "/".join(f"{key}={value}" for key, value in parts.items())
def partition_strings_to_path(
root_path: str, parts: Dict[str, str], partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__"
) -> str:
keys = parts.keys()
values = [partition_null_fallback if value is None else value for value in parts.values()]
postfix = "/".join(f"{k}={v}" for k, v in zip(keys, values))
return f"{root_path}/{postfix}"


def partition_values_to_string(
partition_values: MicroPartition, partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__"
) -> MicroPartition:
"""Convert partition values to human-readable string representation, filling nulls with `partition_null_fallback`."""
default_part = Series.from_pylist([partition_null_fallback])
pkey_names = partition_values.column_names()

partition_strings = {}

for c in pkey_names:
column = partition_values.get_column(c)
string_names = column._to_str_values()
null_filled = column.is_null().if_else(default_part, string_names)
partition_strings[c] = null_filled.to_pylist()

return MicroPartition.from_pydict(partition_strings)


class PartitionedTable:
def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]):
self.table = table
@@ -63,3 +49,27 @@ def partition_values(self) -> Optional[MicroPartition]:
if self._partition_values is None:
self._create_partitions()
return self._partition_values

def partition_values_str(self) -> Optional[MicroPartition]:
"""
Returns the partition values converted to human-readable strings, keeping null values as null.
If the table is not partitioned, returns None.
"""
null_part = Series.from_pylist([None])
partition_values = self.partition_values()

if partition_values is None:
return None
else:
pkey_names = partition_values.column_names()

partition_strings = {}

for c in pkey_names:
column = partition_values.get_column(c)
string_names = column._to_str_values()
null_filled = column.is_null().if_else(null_part, string_names)
partition_strings[c] = null_filled

return MicroPartition.from_pydict(partition_strings)
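
And a sketch of the new `partition_values_str` method (illustrative column and values; the `ExpressionsProjection` construction is an assumption about its constructor). Unlike the removed `partition_values_to_string` helper, nulls stay null instead of being filled with a fallback string:

```python
from daft import col
from daft.expressions import ExpressionsProjection
from daft.table import MicroPartition
from daft.table.partitioning import PartitionedTable

mp = MicroPartition.from_pydict({"country": ["US", None, "CA"], "value": [1, 2, 3]})
partitioned = PartitionedTable(mp, ExpressionsProjection([col("country")]))  # assumed constructor usage

partitioned.partition_values().to_pydict()      # original dtypes, nulls included
partitioned.partition_values_str().to_pydict()  # string values, nulls kept as None
```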