From b519944fc6fea6ed92fc52196ffad99b9c6bd90a Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 24 Sep 2024 15:10:10 -0700 Subject: [PATCH] [FEAT] Delta Lake partitioned writing (#2884) Some things that I will cover in follow-up PRs: - split `table_io.py` up into multiple files - fix partitioned writes to conform to hive style (binary encoding, string escaping, etc) This should not actually be blocking since partition values in the delta log do not actually need to be encoded (Spark does not do so), just stringified. Just don't read it as a hive table lol --- Cargo.lock | 12 ++ Cargo.toml | 1 + daft/daft/__init__.pyi | 1 + daft/dataframe/dataframe.py | 43 +++-- daft/execution/execution_step.py | 2 + daft/execution/physical_plan.py | 2 + daft/execution/rust_physical_plan_shim.py | 2 + daft/iceberg/iceberg_write.py | 12 +- daft/logical/builder.py | 2 + daft/table/partitioning.py | 52 +++--- daft/table/table_io.py | 207 +++++++++++++--------- src/daft-plan/Cargo.toml | 1 + src/daft-plan/src/builder.rs | 4 + src/daft-plan/src/logical_ops/sink.rs | 3 +- src/daft-plan/src/sink_info.rs | 88 +++------ src/daft-scheduler/src/scheduler.rs | 1 + tests/io/delta_lake/conftest.py | 2 +- tests/io/delta_lake/test_table_write.py | 95 ++++++++++ 18 files changed, 340 insertions(+), 190 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ad43bff338..27e54099d0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2073,6 +2073,7 @@ dependencies = [ "daft-scan", "daft-schema", "daft-table", + "derivative", "indexmap 2.5.0", "itertools 0.11.0", "log", @@ -2239,6 +2240,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "derive_more" version = "1.0.0" diff --git a/Cargo.toml b/Cargo.toml index 53f9b894b9..39d1d17ccb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,6 +149,7 @@ bytes = "1.6.0" chrono = "0.4.38" chrono-tz = "0.8.4" comfy-table = "7.1.1" +derivative = "2.2.0" dyn-clone = "1" futures = "0.3.30" html-escape = "0.2.13" diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 1a5dc99f0f..d071a3d85e 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1720,6 +1720,7 @@ class LogicalPlanBuilder: mode: str, version: int, large_dtypes: bool, + partition_cols: list[str] | None = None, io_config: IOConfig | None = None, ) -> LogicalPlanBuilder: ... def lance_write( diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index a4e48caba2..6211423e94 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -766,6 +766,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> def write_deltalake( self, table: Union[str, pathlib.Path, "DataCatalogTable", "deltalake.DeltaTable"], + partition_cols: Optional[List[str]] = None, mode: Literal["append", "overwrite", "error", "ignore"] = "append", schema_mode: Optional[Literal["merge", "overwrite"]] = None, name: Optional[str] = None, @@ -783,6 +784,7 @@ def write_deltalake( Args: table (Union[str, pathlib.Path, DataCatalogTable, deltalake.DeltaTable]): Destination `Delta Lake Table `__ or table URI to write dataframe to. + partition_cols (List[str], optional): How to subpartition each partition further. 
If table exists, expected to match table's existing partitioning scheme, otherwise creates the table with specified partition columns. Defaults to None. mode (str, optional): Operation mode of the write. `append` will add new data, `overwrite` will replace table with new data, `error` will raise an error if table already exists, and `ignore` will not write anything if table already exists. Defaults to "append". schema_mode (str, optional): Schema mode of the write. If set to `overwrite`, allows replacing the schema of the table when doing `mode=overwrite`. Schema mode `merge` is currently not supported. name (str, optional): User-provided identifier for this table. @@ -802,10 +804,7 @@ def write_deltalake( import deltalake import pyarrow as pa from deltalake.schema import _convert_pa_schema_to_delta - from deltalake.writer import ( - try_get_deltatable, - write_deltalake_pyarrow, - ) + from deltalake.writer import AddAction, try_get_deltatable, write_deltalake_pyarrow from packaging.version import parse from daft import from_pydict @@ -861,6 +860,13 @@ def write_deltalake( delta_schema = _convert_pa_schema_to_delta(pyarrow_schema, **large_dtypes_kwargs(large_dtypes)) if table: + if partition_cols and partition_cols != table.metadata().partition_columns: + raise ValueError( + f"Expected partition columns to match that of the existing table ({table.metadata().partition_columns}), but received: {partition_cols}" + ) + else: + partition_cols = table.metadata().partition_columns + table.update_incremental() table_schema = table.schema().to_pyarrow(as_large_types=large_dtypes) @@ -884,42 +890,45 @@ def write_deltalake( else: version = 0 + if partition_cols is not None: + for c in partition_cols: + if self.schema()[c].dtype == DataType.binary(): + raise NotImplementedError("Binary partition columns are not yet supported for Delta Lake writes") + builder = self._builder.write_deltalake( table_uri, mode, version, large_dtypes, io_config=io_config, + partition_cols=partition_cols, ) write_df = DataFrame(builder) write_df.collect() write_result = write_df.to_pydict() - assert "data_file" in write_result - data_files = write_result["data_file"] - add_action = [] + assert "add_action" in write_result + add_actions: List[AddAction] = write_result["add_action"] operations = [] paths = [] rows = [] sizes = [] - for data_file in data_files: - stats = json.loads(data_file.stats) + for add_action in add_actions: + stats = json.loads(add_action.stats) operations.append("ADD") - paths.append(data_file.path) + paths.append(add_action.path) rows.append(stats["numRecords"]) - sizes.append(data_file.size) - - add_action.append(data_file) + sizes.append(add_action.size) if table is None: write_deltalake_pyarrow( table_uri, delta_schema, - add_action, + add_actions, mode, - [], + partition_cols or [], name, description, configuration, @@ -936,7 +945,9 @@ def write_deltalake( rows.append(old_actions_dict["num_records"][i]) sizes.append(old_actions_dict["size_bytes"][i]) - table._table.create_write_transaction(add_action, mode, [], delta_schema, None, custom_metadata) + table._table.create_write_transaction( + add_actions, mode, partition_cols or [], delta_schema, None, custom_metadata + ) table.update_incremental() with_operations = from_pydict( diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 7693a0a84c..daa9afa289 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -429,6 +429,7 @@ class WriteDeltaLake(SingleOutputInstruction): base_path: str 
large_dtypes: bool version: int + partition_cols: list[str] | None io_config: IOConfig | None def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: @@ -456,6 +457,7 @@ def _handle_file_write(self, input: MicroPartition) -> MicroPartition: large_dtypes=self.large_dtypes, base_path=self.base_path, version=self.version, + partition_cols=self.partition_cols, io_config=self.io_config, ) diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 273ee0dc49..220731b6b5 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -147,6 +147,7 @@ def deltalake_write( base_path: str, large_dtypes: bool, version: int, + partition_cols: list[str] | None, io_config: IOConfig | None, ) -> InProgressPhysicalPlan[PartitionT]: """Write the results of `child_plan` into pyiceberg data files described by `write_info`.""" @@ -157,6 +158,7 @@ def deltalake_write( base_path=base_path, large_dtypes=large_dtypes, version=version, + partition_cols=partition_cols, io_config=io_config, ), ) diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index 351d27f3bb..a19a3e2ad8 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -363,6 +363,7 @@ def write_deltalake( path: str, large_dtypes: bool, version: int, + partition_cols: list[str] | None, io_config: IOConfig | None, ) -> physical_plan.InProgressPhysicalPlan[PartitionT]: return physical_plan.deltalake_write( @@ -370,6 +371,7 @@ def write_deltalake( path, large_dtypes, version, + partition_cols, io_config, ) diff --git a/daft/iceberg/iceberg_write.py b/daft/iceberg/iceberg_write.py index 8bc4d1431b..0de4c950d8 100644 --- a/daft/iceberg/iceberg_write.py +++ b/daft/iceberg/iceberg_write.py @@ -5,7 +5,7 @@ from daft import Expression, col from daft.table import MicroPartition -from daft.table.partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string +from daft.table.partitioning import PartitionedTable, partition_strings_to_path if TYPE_CHECKING: import pyarrow as pa @@ -222,13 +222,15 @@ def partitioned_table_to_iceberg_iter( partition_values = partitioned.partition_values() if partition_values: - partition_strings = partition_values_to_string(partition_values, partition_null_fallback="null").to_pylist() - partition_values_list = partition_values.to_pylist() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None - for table, part_vals, part_strs in zip(partitioned.partitions(), partition_values_list, partition_strings): + for table, part_vals, part_strs in zip( + partitioned.partitions(), partition_values.to_pylist(), partition_strings.to_pylist() + ): iceberg_part_vals = {k: to_partition_representation(v) for k, v in part_vals.items()} part_record = IcebergRecord(**iceberg_part_vals) - part_path = partition_strings_to_path(root_path, part_strs) + part_path = partition_strings_to_path(root_path, part_strs, partition_null_fallback="null") arrow_table = coerce_pyarrow_table_to_schema(table.to_arrow(), schema) diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 31b347295d..7f8ed96cf2 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -304,6 +304,7 @@ def write_deltalake( version: int, large_dtypes: bool, io_config: IOConfig, + partition_cols: list[str] | None = None, ) -> LogicalPlanBuilder: columns_name = self.schema().column_names() builder = self._builder.delta_write( @@ -312,6 +313,7 @@ def 
write_deltalake( mode, version, large_dtypes, + partition_cols, io_config, ) return LogicalPlanBuilder(builder) diff --git a/daft/table/partitioning.py b/daft/table/partitioning.py index fac841d346..70a590cb45 100644 --- a/daft/table/partitioning.py +++ b/daft/table/partitioning.py @@ -1,34 +1,20 @@ from typing import Dict, List, Optional +from daft import Series from daft.expressions import ExpressionsProjection -from daft.series import Series from .micropartition import MicroPartition -def partition_strings_to_path(root_path: str, parts: Dict[str, str]): - postfix = "/".join(f"{key}={value}" for key, value in parts.items()) +def partition_strings_to_path( + root_path: str, parts: Dict[str, str], partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" +) -> str: + keys = parts.keys() + values = [partition_null_fallback if value is None else value for value in parts.values()] + postfix = "/".join(f"{k}={v}" for k, v in zip(keys, values)) return f"{root_path}/{postfix}" -def partition_values_to_string( - partition_values: MicroPartition, partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" -) -> MicroPartition: - """Convert partition values to human-readable string representation, filling nulls with `partition_null_fallback`.""" - default_part = Series.from_pylist([partition_null_fallback]) - pkey_names = partition_values.column_names() - - partition_strings = {} - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(default_part, string_names) - partition_strings[c] = null_filled.to_pylist() - - return MicroPartition.from_pydict(partition_strings) - - class PartitionedTable: def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]): self.table = table @@ -63,3 +49,27 @@ def partition_values(self) -> Optional[MicroPartition]: if self._partition_values is None: self._create_partitions() return self._partition_values + + def partition_values_str(self) -> Optional[MicroPartition]: + """ + Returns the partition values converted to human-readable strings, keeping null values as null. + + If the table is not partitioned, returns None. 
+ """ + null_part = Series.from_pylist([None]) + partition_values = self.partition_values() + + if partition_values is None: + return None + else: + pkey_names = partition_values.column_names() + + partition_strings = {} + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(null_part, string_names) + partition_strings[c] = null_filled + + return MicroPartition.from_pydict(partition_strings) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index ee366946b2..ba07fab8a4 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -5,7 +5,6 @@ import pathlib import random import time -from functools import partial from typing import IO, TYPE_CHECKING, Any, Iterator, Union from uuid import uuid4 @@ -24,7 +23,7 @@ StorageConfig, ) from daft.dependencies import pa, pacsv, pads, pajson, pq -from daft.expressions import ExpressionsProjection +from daft.expressions import ExpressionsProjection, col from daft.filesystem import ( _resolve_paths_and_filesystem, canonicalize_protocol, @@ -40,13 +39,14 @@ from daft.sql.sql_connection import SQLConnection from .micropartition import MicroPartition -from .partitioning import PartitionedTable, partition_strings_to_path, partition_values_to_string +from .partitioning import PartitionedTable, partition_strings_to_path FileInput = Union[pathlib.Path, str, IO[bytes]] if TYPE_CHECKING: from collections.abc import Callable, Generator + from deltalake.writer import AddAction from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -404,9 +404,10 @@ def partitioned_table_to_hive_iter(partitioned: PartitionedTable, root_path: str partition_values = partitioned.partition_values() if partition_values: - partition_strings = partition_values_to_string(partition_values).to_pylist() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None - for part_table, part_strs in zip(partitioned.partitions(), partition_strings): + for part_table, part_strs in zip(partitioned.partitions(), partition_strings.to_pylist()): part_path = partition_strings_to_path(root_path, part_strs) arrow_table = part_table.to_arrow() @@ -595,112 +596,160 @@ def write_iceberg( return visitors.to_metadata() +def partitioned_table_to_deltalake_iter( + partitioned: PartitionedTable, large_dtypes: bool +) -> Iterator[tuple[pa.Table, str, dict[str, str | None]]]: + """ + Iterates over partitions, yielding each partition as an Arrow table, along with their respective paths and partition values. 
+ """ + from deltalake.schema import _convert_pa_schema_to_delta + + from daft.io._deltalake import large_dtypes_kwargs + + partition_values = partitioned.partition_values() + + if partition_values: + partition_keys = partition_values.column_names() + partition_strings = partitioned.partition_values_str() + assert partition_strings is not None + + for part_table, part_strs in zip(partitioned.partitions(), partition_strings.to_pylist()): + part_path = partition_strings_to_path("", part_strs) + arrow_table = part_table.to_arrow() + + # Remove partition keys from the table since they are already encoded as keys + arrow_table_no_pkeys = arrow_table.drop_columns(partition_keys) + + converted_schema = _convert_pa_schema_to_delta( + arrow_table_no_pkeys.schema, **large_dtypes_kwargs(large_dtypes) + ) + converted_arrow_table = arrow_table_no_pkeys.cast(converted_schema) + + yield converted_arrow_table, part_path, part_strs + else: + arrow_table = partitioned.table.to_arrow() + arrow_batch = _convert_pa_schema_to_delta(arrow_table.schema, **large_dtypes_kwargs(large_dtypes)) + converted_arrow_table = arrow_table.cast(arrow_batch) + + yield converted_arrow_table, "/", {} + + +class DeltaLakeWriteVisitors: + class FileVisitor: + def __init__(self, parent: DeltaLakeWriteVisitors, partition_values: dict[str, str | None]): + self.parent = parent + self.partition_values = partition_values + + def __call__(self, written_file): + import json + from datetime import datetime + + import deltalake + from deltalake.writer import AddAction, DeltaJSONEncoder, get_file_stats_from_metadata + from packaging.version import parse + + from daft.utils import get_arrow_version + + # added to get_file_stats_from_metadata in deltalake v0.17.4: non-optional "num_indexed_cols" and "columns_to_collect_stats" arguments + # https://github.com/delta-io/delta-rs/blob/353e08be0202c45334dcdceee65a8679f35de710/python/deltalake/writer.py#L725 + if parse(deltalake.__version__) < parse("0.17.4"): + file_stats_args = {} + else: + file_stats_args = {"num_indexed_cols": -1, "columns_to_collect_stats": None} + + stats = get_file_stats_from_metadata(written_file.metadata, **file_stats_args) + + # PyArrow added support for written_file.size in 9.0.0 + if get_arrow_version() >= (9, 0, 0): + size = written_file.size + elif self.parent.fs is not None: + size = self.parent.fs.get_file_info([written_file.path])[0].size + else: + size = 0 + + self.parent.add_actions.append( + AddAction( + written_file.path, + size, + self.partition_values, + int(datetime.now().timestamp() * 1000), + True, + json.dumps(stats, cls=DeltaJSONEncoder), + ) + ) + + def __init__(self, fs: pa.fs.FileSystem): + self.add_actions: list[AddAction] = [] + self.fs = fs + + def visitor(self, partition_values: dict[str, str | None]) -> DeltaLakeWriteVisitors.FileVisitor: + return self.FileVisitor(self, partition_values) + + def to_metadata(self) -> MicroPartition: + return MicroPartition.from_pydict({"add_action": self.add_actions}) + + def write_deltalake( - mp: MicroPartition, + table: MicroPartition, large_dtypes: bool, base_path: str, version: int, + partition_cols: list[str] | None = None, io_config: IOConfig | None = None, ): - import json - from datetime import datetime - - import deltalake - from deltalake.schema import convert_pyarrow_table - from deltalake.writer import ( - AddAction, - DeltaJSONEncoder, - DeltaStorageHandler, - get_partitions_from_path, - ) - from packaging.version import parse + from deltalake.writer import DeltaStorageHandler from pyarrow.fs import 
PyFileSystem - from daft.io._deltalake import large_dtypes_kwargs from daft.io.object_store_options import io_config_to_storage_options - from daft.utils import get_arrow_version protocol = get_protocol_from_path(base_path) canonicalized_protocol = canonicalize_protocol(protocol) - data_files: list[AddAction] = [] - - # added to get_file_stats_from_metadata in deltalake v0.17.4: non-optional "num_indexed_cols" and "columns_to_collect_stats" arguments - # https://github.com/delta-io/delta-rs/blob/353e08be0202c45334dcdceee65a8679f35de710/python/deltalake/writer.py#L725 - if parse(deltalake.__version__) < parse("0.17.4"): - get_file_stats_from_metadata = deltalake.writer.get_file_stats_from_metadata - else: - get_file_stats_from_metadata = partial( - deltalake.writer.get_file_stats_from_metadata, num_indexed_cols=-1, columns_to_collect_stats=None - ) - - def file_visitor(written_file: Any) -> None: - path, partition_values = get_partitions_from_path(written_file.path) - stats = get_file_stats_from_metadata(written_file.metadata) - - # PyArrow added support for written_file.size in 9.0.0 - if get_arrow_version() >= (9, 0, 0): - size = written_file.size - elif fs is not None: - size = fs.get_file_info([path])[0].size - else: - size = 0 - - data_files.append( - AddAction( - path, - size, - partition_values, - int(datetime.now().timestamp() * 1000), - True, - json.dumps(stats, cls=DeltaJSONEncoder), - ) - ) - is_local_fs = canonicalized_protocol == "file" io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config storage_options = io_config_to_storage_options(io_config, base_path) fs = PyFileSystem(DeltaStorageHandler(base_path, storage_options)) - arrow_table = mp.to_arrow() - arrow_batch = convert_pyarrow_table(arrow_table, **large_dtypes_kwargs(large_dtypes)) - execution_config = get_context().daft_execution_config target_row_group_size = execution_config.parquet_target_row_group_size inflation_factor = execution_config.parquet_inflation_factor target_file_size = execution_config.parquet_target_filesize - size_bytes = arrow_table.nbytes + format = pads.ParquetFileFormat() + opts = format.make_write_options(use_compliant_nested_type=False) - target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) - num_rows = len(arrow_table) + partition_keys = ExpressionsProjection([col(c) for c in partition_cols]) if partition_cols is not None else None + partitioned = PartitionedTable(table, partition_keys) + visitors = DeltaLakeWriteVisitors(fs) - rows_per_file = max(math.ceil(num_rows / target_num_files), 1) + for part_table, part_path, part_values in partitioned_table_to_deltalake_iter(partitioned, large_dtypes): + size_bytes = part_table.nbytes - target_row_groups = max(math.ceil(size_bytes / target_row_group_size / inflation_factor), 1) - rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) + target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1) + num_rows = len(part_table) - format = pads.ParquetFileFormat() + rows_per_file = max(math.ceil(num_rows / target_num_files), 1) - opts = format.make_write_options(use_compliant_nested_type=False) + target_row_groups = max(math.ceil(size_bytes / target_row_group_size / inflation_factor), 1) + rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1) - _write_tabular_arrow_table( - arrow_table=arrow_batch, - schema=None, - full_path="/", - format=format, - opts=opts, - fs=fs, - 
rows_per_file=rows_per_file, - rows_per_row_group=rows_per_row_group, - create_dir=is_local_fs, - file_visitor=file_visitor, - version=version, - ) + _write_tabular_arrow_table( + arrow_table=part_table, + schema=None, + full_path=part_path, + format=format, + opts=opts, + fs=fs, + rows_per_file=rows_per_file, + rows_per_row_group=rows_per_row_group, + create_dir=is_local_fs, + file_visitor=visitors.visitor(part_values), + version=version, + ) - return MicroPartition.from_pydict({"data_file": Series.from_pylist(data_files, name="data_file", pyobj="force")}) + return visitors.to_metadata() def write_lance(mp: MicroPartition, base_path: str, mode: str, io_config: IOConfig | None, kwargs: dict | None): diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index d2cd422dba..a8306394f1 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -27,6 +27,7 @@ daft-functions = {path = "../daft-functions", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-schema = {path = "../daft-schema", default-features = false} daft-table = {path = "../daft-table", default-features = false} +derivative = {workspace = true} indexmap = {workspace = true} itertools = {workspace = true} log = {workspace = true} diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 0098d72405..98740a6a53 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -432,6 +432,7 @@ impl LogicalPlanBuilder { mode: String, version: i32, large_dtypes: bool, + partition_cols: Option>, io_config: Option, ) -> DaftResult { use crate::sink_info::DeltaLakeCatalogInfo; @@ -441,6 +442,7 @@ impl LogicalPlanBuilder { mode, version, large_dtypes, + partition_cols, io_config, }), catalog_columns: columns_name, @@ -752,6 +754,7 @@ impl PyLogicalPlanBuilder { mode: String, version: i32, large_dtypes: bool, + partition_cols: Option>, io_config: Option, ) -> PyResult { Ok(self @@ -762,6 +765,7 @@ impl PyLogicalPlanBuilder { mode, version, large_dtypes, + partition_cols, io_config.map(|cfg| cfg.config), )? 
.into()) diff --git a/src/daft-plan/src/logical_ops/sink.rs b/src/daft-plan/src/logical_ops/sink.rs index 2a23292c44..d84c654c84 100644 --- a/src/daft-plan/src/logical_ops/sink.rs +++ b/src/daft-plan/src/logical_ops/sink.rs @@ -21,6 +21,7 @@ impl Sink { pub(crate) fn try_new(input: Arc, sink_info: Arc) -> DaftResult { let schema = input.schema(); + // replace partition columns with resolved columns let sink_info = match sink_info.as_ref() { SinkInfo::OutputFileInfo(OutputFileInfo { root_dir, @@ -67,7 +68,7 @@ impl Sink { Field::new("data_file", DataType::Python), ] } - CatalogType::DeltaLake(_) => vec![Field::new("data_file", DataType::Python)], + CatalogType::DeltaLake(_) => vec![Field::new("add_action", DataType::Python)], CatalogType::Lance(_) => vec![Field::new("fragments", DataType::Python)], } } diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index b66217d8d2..02c8e05273 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -5,6 +5,7 @@ use common_io_config::IOConfig; #[cfg(feature = "python")] use common_py_serde::{deserialize_py_object, serialize_py_object}; use daft_dsl::ExprRef; +use derivative::Derivative; use itertools::Itertools; #[cfg(feature = "python")] use pyo3::PyObject; @@ -43,7 +44,8 @@ pub enum CatalogType { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Derivative, Debug, Clone, Serialize, Deserialize)] +#[derivative(PartialEq, Eq, Hash)] pub struct IcebergCatalogInfo { pub table_name: String, pub table_location: String, @@ -51,41 +53,27 @@ pub struct IcebergCatalogInfo { serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub partition_spec: PyObject, #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub iceberg_schema: PyObject, #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub iceberg_properties: PyObject, pub io_config: Option, } -#[cfg(feature = "python")] -impl PartialEq for IcebergCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.table_name == other.table_name - && self.table_location == other.table_location - && self.io_config == other.io_config - } -} -#[cfg(feature = "python")] -impl Eq for IcebergCatalogInfo {} - -#[cfg(feature = "python")] -impl Hash for IcebergCatalogInfo { - fn hash(&self, state: &mut H) { - self.table_name.hash(state); - self.table_location.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl IcebergCatalogInfo { pub fn multiline_display(&self) -> Vec { @@ -101,40 +89,16 @@ impl IcebergCatalogInfo { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct DeltaLakeCatalogInfo { pub path: String, pub mode: String, pub version: i32, pub large_dtypes: bool, + pub partition_cols: Option>, pub io_config: Option, } -#[cfg(feature = "python")] -impl PartialEq for DeltaLakeCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.path == other.path - && self.mode == other.mode - && self.version == other.version - && self.large_dtypes == other.large_dtypes - && self.io_config == other.io_config - } -} - -#[cfg(feature = "python")] -impl Eq for DeltaLakeCatalogInfo {} - 
-#[cfg(feature = "python")] -impl Hash for DeltaLakeCatalogInfo { - fn hash(&self, state: &mut H) { - self.path.hash(state); - self.mode.hash(state); - self.version.hash(state); - self.large_dtypes.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl DeltaLakeCatalogInfo { pub fn multiline_display(&self) -> Vec { @@ -143,6 +107,12 @@ impl DeltaLakeCatalogInfo { res.push(format!("Mode = {}", self.mode)); res.push(format!("Version = {}", self.version)); res.push(format!("Large Dtypes = {}", self.large_dtypes)); + if let Some(ref partition_cols) = self.partition_cols { + res.push(format!( + "Partition cols = {}", + partition_cols.iter().map(|e| e.to_string()).join(", ") + )); + } match &self.io_config { None => res.push("IOConfig = None".to_string()), Some(io_config) => res.push(format!("IOConfig = {}", io_config)), @@ -152,7 +122,8 @@ impl DeltaLakeCatalogInfo { } #[cfg(feature = "python")] -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Derivative, Debug, Clone, Serialize, Deserialize)] +#[derivative(PartialEq, Eq, Hash)] pub struct LanceCatalogInfo { pub path: String, pub mode: String, @@ -161,28 +132,11 @@ pub struct LanceCatalogInfo { serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" )] + #[derivative(PartialEq = "ignore")] + #[derivative(Hash = "ignore")] pub kwargs: PyObject, } -#[cfg(feature = "python")] -impl PartialEq for LanceCatalogInfo { - fn eq(&self, other: &Self) -> bool { - self.path == other.path && self.mode == other.mode && self.io_config == other.io_config - } -} - -#[cfg(feature = "python")] -impl Eq for LanceCatalogInfo {} - -#[cfg(feature = "python")] -impl Hash for LanceCatalogInfo { - fn hash(&self, state: &mut H) { - self.path.hash(state); - self.mode.hash(state); - self.io_config.hash(state); - } -} - #[cfg(feature = "python")] impl LanceCatalogInfo { pub fn multiline_display(&self) -> Vec { diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 2eb66781d9..709dd8ff4d 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -207,6 +207,7 @@ fn deltalake_write( &delta_lake_info.path, delta_lake_info.large_dtypes, delta_lake_info.version, + delta_lake_info.partition_cols.clone(), delta_lake_info .io_config .as_ref() diff --git a/tests/io/delta_lake/conftest.py b/tests/io/delta_lake/conftest.py index 052d5a6d74..363351eb95 100644 --- a/tests/io/delta_lake/conftest.py +++ b/tests/io/delta_lake/conftest.py @@ -34,7 +34,7 @@ def num_partitions(request) -> int: pytest.param((lambda i: i, "a"), id="int_partitioned"), pytest.param((lambda i: i * 1.5, "b"), id="float_partitioned"), pytest.param((lambda i: f"foo_{i}", "c"), id="string_partitioned"), - pytest.param((lambda i: f"foo_{i}".encode(), "d"), id="string_partitioned"), + pytest.param((lambda i: f"foo_{i}".encode(), "d"), id="binary_partitioned"), pytest.param( (lambda i: datetime.datetime(2024, 2, i + 1), "f"), id="timestamp_partitioned", diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py index 6dbcf539fa..6519e85d0f 100644 --- a/tests/io/delta_lake/test_table_write.py +++ b/tests/io/delta_lake/test_table_write.py @@ -1,6 +1,9 @@ from __future__ import annotations +import datetime +import decimal import sys +from pathlib import Path import pyarrow as pa import pytest @@ -180,3 +183,95 @@ def test_deltalake_write_ignore(tmp_path): expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) assert df1.schema() 
== expected_schema assert read_delta.to_pyarrow_table() == df1.to_arrow() + + +def check_equal_both_daft_and_delta_rs(df: daft.DataFrame, path: Path, sort_order: list[tuple[str, str]]): + deltalake = pytest.importorskip("deltalake") + + arrow_df = df.to_arrow().sort_by(sort_order) + + read_daft = daft.read_deltalake(str(path)) + assert read_daft.schema() == df.schema() + assert read_daft.to_arrow().sort_by(sort_order) == arrow_df + + read_delta = deltalake.DeltaTable(str(path)) + expected_schema = Schema.from_pyarrow_schema(read_delta.schema().to_pyarrow()) + assert df.schema() == expected_schema + assert read_delta.to_pyarrow_table().cast(expected_schema.to_pyarrow_schema()).sort_by(sort_order) == arrow_df + + +@pytest.mark.parametrize( + "partition_cols,num_partitions", + [ + (["int"], 3), + (["float"], 3), + (["str"], 3), + pytest.param(["bin"], 3, marks=pytest.mark.xfail(reason="Binary partitioning is not yet supported")), + (["bool"], 3), + (["datetime"], 3), + (["date"], 3), + (["decimal"], 3), + (["int", "float"], 4), + ], +) +def test_deltalake_write_partitioned(tmp_path, partition_cols, num_partitions): + path = tmp_path / "some_table" + df = daft.from_pydict( + { + "int": [1, 1, 2, None], + "float": [1.1, 2.2, 2.2, None], + "str": ["foo", "foo", "bar", None], + "bin": [b"foo", b"foo", b"bar", None], + "bool": [True, True, False, None], + "datetime": [ + datetime.datetime(2024, 2, 10), + datetime.datetime(2024, 2, 10), + datetime.datetime(2024, 2, 11), + None, + ], + "date": [datetime.date(2024, 2, 10), datetime.date(2024, 2, 10), datetime.date(2024, 2, 11), None], + "decimal": pa.array( + [decimal.Decimal("1111.111"), decimal.Decimal("1111.111"), decimal.Decimal("2222.222"), None], + type=pa.decimal128(7, 3), + ), + } + ) + result = df.write_deltalake(str(path), partition_cols=partition_cols) + result = result.to_pydict() + assert len(result["operation"]) == num_partitions + assert all(op == "ADD" for op in result["operation"]) + assert sum(result["rows"]) == len(df) + + sort_order = [("int", "ascending"), ("float", "ascending")] + check_equal_both_daft_and_delta_rs(df, path, sort_order) + + +def test_deltalake_write_partitioned_empty(tmp_path): + path = tmp_path / "some_table" + + df = daft.from_arrow(pa.schema([("int", pa.int64()), ("string", pa.string())]).empty_table()) + + df.write_deltalake(str(path), partition_cols=["int"]) + + check_equal_both_daft_and_delta_rs(df, path, [("int", "ascending")]) + + +def test_deltalake_write_partitioned_existing_table(tmp_path): + path = tmp_path / "some_table" + + df1 = daft.from_pydict({"int": [1], "string": ["foo"]}) + result = df1.write_deltalake(str(path), partition_cols=["int"]) + result = result.to_pydict() + assert result["operation"] == ["ADD"] + assert result["rows"] == [1] + + df2 = daft.from_pydict({"int": [1, 2], "string": ["bar", "bar"]}) + with pytest.raises(ValueError): + df2.write_deltalake(str(path), partition_cols=["string"]) + + result = df2.write_deltalake(str(path)) + result = result.to_pydict() + assert result["operation"] == ["ADD", "ADD"] + assert result["rows"] == [1, 1] + + check_equal_both_daft_and_delta_rs(df1.concat(df2), path, [("int", "ascending"), ("string", "ascending")])
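
---

For reference, a minimal usage sketch of the partitioned write path this PR adds, mirroring the new tests above. The local path and column names here are illustrative, not part of the patch; only the `partition_cols` argument and the `operation`/`rows` result columns are taken from the changes themselves.

```python
import daft

# Illustrative data; "int" and "string" are hypothetical column names.
df = daft.from_pydict({"int": [1, 1, 2], "string": ["foo", "bar", "bar"]})

# Write a new partitioned Delta Lake table; files are written per partition value.
result = df.write_deltalake("/tmp/some_table", partition_cols=["int"])

# The returned DataFrame summarizes the add actions committed to the Delta log,
# e.g. one "ADD" operation per written file and the row count of each file.
summary = result.to_pydict()
assert all(op == "ADD" for op in summary["operation"])
assert sum(summary["rows"]) == 3

# Appending to an existing table with a mismatched partitioning scheme raises,
# since partition_cols must match the table's existing partition columns.
try:
    df.write_deltalake("/tmp/some_table", partition_cols=["string"])
except ValueError as e:
    print(e)
```

As noted in the commit message, partition values are stringified but not yet hive-encoded (binary encoding, string escaping), so the resulting directory layout should not be read back as a Hive-partitioned table.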