
Commit 2f55d07

[Data] - Iceberg support upsert tables + schema update + overwrite tables (#58270)
## Description

- Support upserting Iceberg tables for `IcebergDatasink`
- Update the schema on APPEND and UPSERT
- Enable overwriting the entire table

Upgrades to PyIceberg 0.10.0, which now supports upsert and overwrite functionality. For append, the library now handles the transaction logic implicitly, so that burden can be lifted from Ray Data.

## Related issues

> Link related issues: "Fixes #1234", "Closes #1234", or "Related to #1234".

## Additional information

> Optional: Add implementation details, API changes, usage examples, screenshots, etc.

---------

Signed-off-by: Goutam <[email protected]>
1 parent d6793ec commit 2f55d07
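A rough usage sketch of the new write modes (not part of this PR's diff): it drives the `DeveloperAPI` sink directly through `Dataset.write_datasink`, the catalog settings and table name are placeholders, and any matching parameters added to the public `Dataset.write_iceberg` API are not shown in this excerpt.

```python
import ray
from ray.data import SaveMode
from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink

# Toy dataset; "id" acts as the upsert key below.
ds = ray.data.from_items([{"id": 1, "value": "a"}, {"id": 2, "value": "b"}])

# Placeholder catalog config (a local SQLite-backed PyIceberg catalog);
# assumes the table default.example_table already exists in that catalog.
sink = IcebergDatasink(
    table_identifier="default.example_table",
    catalog_kwargs={"name": "default", "type": "sql", "uri": "sqlite:///catalog.db"},
    mode=SaveMode.UPSERT,
    # Rows matching on "id" are updated; everything else is inserted.
    upsert_kwargs={"join_cols": ["id"]},
)
ds.write_datasink(sink)
```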

File tree

8 files changed, +879 -106 lines changed

ci/lint/pydoclint-baseline.txt

Lines changed: 0 additions & 4 deletions
@@ -1018,10 +1018,6 @@ python/ray/data/_internal/block_batching/util.py
     DOC402: Function `finalize_batches` has "yield" statements, but the docstring does not have a "Yields" section
     DOC404: Function `finalize_batches` yield type(s) in docstring not consistent with the return annotation. Return annotation exists, but docstring "yields" section does not exist or has 0 type(s).
 --------------------
-python/ray/data/_internal/datasource/iceberg_datasink.py
-    DOC102: Method `IcebergDatasink.__init__`: Docstring contains more arguments than in function signature.
-    DOC103: Method `IcebergDatasink.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the docstring but not in the function signature: [to an iceberg table, e.g. {"commit_time": ].
---------------------
 python/ray/data/_internal/datasource/lance_datasink.py
     DOC101: Method `LanceDatasink.__init__`: Docstring contains fewer arguments than in function signature.
     DOC103: Method `LanceDatasink.__init__`: Docstring arguments are different from function arguments. (Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Arguments in the function signature but not in the docstring: [**kwargs: , *args: , max_rows_per_file: int, min_rows_per_file: int, mode: Literal['create', 'append', 'overwrite'], schema: Optional[pa.Schema], storage_options: Optional[Dict[str, Any]], uri: str]. Arguments in the docstring but not in the function signature: [max_rows_per_file : , min_rows_per_file : , mode : , schema : , storage_options : , uri : ].

python/ray/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,7 @@
     FileShuffleConfig,
     ReadTask,
     RowBasedFileDatasink,
+    SaveMode,
 )
 from ray.data.iterator import DataIterator, DatasetIterator
 from ray.data.preprocessor import Preprocessor
@@ -131,6 +132,7 @@
     "NodeIdStr",
     "ReadTask",
     "RowBasedFileDatasink",
+    "SaveMode",
     "Schema",
     "SinkMode",
     "TaskPoolStrategy",

python/ray/data/_internal/datasource/iceberg_datasink.py

Lines changed: 179 additions & 92 deletions
@@ -2,26 +2,27 @@
 Module to write a Ray Dataset into an iceberg table, by using the Ray Datasink API.
 """
 import logging
-import uuid
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
 
-from packaging import version
-
 from ray.data._internal.execution.interfaces import TaskContext
+from ray.data._internal.savemode import SaveMode
 from ray.data.block import Block, BlockAccessor
 from ray.data.datasource.datasink import Datasink, WriteResult
 from ray.util.annotations import DeveloperAPI
 
 if TYPE_CHECKING:
+    import pyarrow as pa
     from pyiceberg.catalog import Catalog
-    from pyiceberg.manifest import DataFile
+    from pyiceberg.table import Table
+
+    from ray.data.expressions import Expr
 
 
 logger = logging.getLogger(__name__)
 
 
 @DeveloperAPI
-class IcebergDatasink(Datasink[List["DataFile"]]):
+class IcebergDatasink(Datasink[List["pa.Table"]]):
     """
     Iceberg datasink to write a Ray Dataset into an existing Iceberg table. This module
     heavily uses PyIceberg to write to iceberg table. All the routines in this class override
@@ -34,137 +35,223 @@ def __init__(
         table_identifier: str,
         catalog_kwargs: Optional[Dict[str, Any]] = None,
         snapshot_properties: Optional[Dict[str, str]] = None,
+        mode: SaveMode = SaveMode.APPEND,
+        overwrite_filter: Optional["Expr"] = None,
+        upsert_kwargs: Optional[Dict[str, Any]] = None,
+        overwrite_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Initialize the IcebergDatasink
 
         Args:
-            table_identifier: The identifier of the table to read e.g. `default.taxi_dataset`
+            table_identifier: The identifier of the table to read such as `default.taxi_dataset`
            catalog_kwargs: Optional arguments to use when setting up the Iceberg
                catalog
-            snapshot_properties: custom properties write to snapshot when committing
-                to an iceberg table, e.g. {"commit_time": "2021-01-01T00:00:00Z"}
+            snapshot_properties: Custom properties to write to snapshot summary, such as commit metadata
+            mode: Write mode - APPEND, UPSERT, or OVERWRITE. Defaults to APPEND.
+                - APPEND: Add new data without checking for duplicates
+                - UPSERT: Update existing rows or insert new ones based on a join condition
+                - OVERWRITE: Replace table data (all data or filtered subset)
+            overwrite_filter: Optional filter for OVERWRITE mode to perform partial overwrites.
+                Must be a Ray Data expression from `ray.data.expressions`. Only rows matching
+                this filter are replaced. If None with OVERWRITE mode, replaces all table data.
+            upsert_kwargs: Optional arguments to pass through to PyIceberg's table.upsert()
+                method. Supported parameters include join_cols (List[str]),
+                when_matched_update_all (bool), when_not_matched_insert_all (bool),
+                case_sensitive (bool), branch (str). See PyIceberg documentation for details.
+            overwrite_kwargs: Optional arguments to pass through to PyIceberg's table.overwrite()
+                method. Supported parameters include case_sensitive (bool) and branch (str).
+                See PyIceberg documentation for details.
+
+        Note:
+            Schema evolution is automatically enabled. New columns in the incoming data
+            are automatically added to the table schema.
         """
 
-        from pyiceberg.io import FileIO
-        from pyiceberg.table import Transaction
-        from pyiceberg.table.metadata import TableMetadata
-
         self.table_identifier = table_identifier
-        self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {}
-        self._snapshot_properties = (
-            snapshot_properties if snapshot_properties is not None else {}
-        )
+        self._catalog_kwargs = (catalog_kwargs or {}).copy()
+        self._snapshot_properties = snapshot_properties or {}
+        self._mode = mode
+        self._overwrite_filter = overwrite_filter
+        self._upsert_kwargs = (upsert_kwargs or {}).copy()
+        self._overwrite_kwargs = (overwrite_kwargs or {}).copy()
+
+        # Validate kwargs are only set for relevant modes
+        if self._upsert_kwargs and self._mode != SaveMode.UPSERT:
+            raise ValueError(
+                f"upsert_kwargs can only be specified when mode is SaveMode.UPSERT, "
+                f"but mode is {self._mode}"
+            )
+        if self._overwrite_kwargs and self._mode != SaveMode.OVERWRITE:
+            raise ValueError(
+                f"overwrite_kwargs can only be specified when mode is SaveMode.OVERWRITE, "
+                f"but mode is {self._mode}"
+            )
 
         if "name" in self._catalog_kwargs:
             self._catalog_name = self._catalog_kwargs.pop("name")
         else:
             self._catalog_name = "default"
 
-        self._uuid: str = None
-        self._io: FileIO = None
-        self._txn: Transaction = None
-        self._table_metadata: TableMetadata = None
+        self._table: "Table" = None
 
-    # Since iceberg transaction is not pickle-able, because of the table and catalog properties
-    # we need to exclude the transaction object during serialization and deserialization during pickle
+    # Since iceberg table is not pickle-able, we need to exclude it during serialization
     def __getstate__(self) -> dict:
-        """Exclude `_txn` during pickling."""
+        """Exclude `_table` during pickling."""
         state = self.__dict__.copy()
-        del state["_txn"]
+        state.pop("_table", None)
         return state
 
     def __setstate__(self, state: dict) -> None:
         self.__dict__.update(state)
-        self._txn = None
+        self._table = None
 
     def _get_catalog(self) -> "Catalog":
         from pyiceberg import catalog
 
         return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)
 
-    def on_write_start(self) -> None:
-        """Prepare for the transaction"""
-        import pyiceberg
-        from pyiceberg.table import TableProperties
+    def _update_schema(self, incoming_schema: "pa.Schema") -> None:
+        """
+        Update the table schema to accommodate incoming data using union-by-name semantics.
 
-        if version.parse(pyiceberg.__version__) >= version.parse("0.9.0"):
-            from pyiceberg.utils.properties import property_as_bool
-        else:
-            from pyiceberg.table import PropertyUtil
+        This automatically handles:
+        - Adding new columns from the incoming schema
+        - Type promotion (e.g., int32 -> int64) where compatible
+        - Preserving existing columns not in the incoming schema
 
-            property_as_bool = PropertyUtil.property_as_bool
+        Args:
+            incoming_schema: The PyArrow schema from the incoming data
+        """
+        # Use PyIceberg's update_schema API
+        with self._table.update_schema() as update:
+            update.union_by_name(incoming_schema)
 
+        # Reload table completely after schema evolution
         catalog = self._get_catalog()
-        table = catalog.load_table(self.table_identifier)
-        self._txn = table.transaction()
-        self._io = self._txn._table.io
-        self._table_metadata = self._txn.table_metadata
-        self._uuid = uuid.uuid4()
-
-        if unsupported_partitions := [
-            field
-            for field in self._table_metadata.spec().fields
-            if not field.transform.supports_pyarrow_transform
-        ]:
-            raise ValueError(
-                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
-            )
+        self._table = catalog.load_table(self.table_identifier)
 
-        self._manifest_merge_enabled = property_as_bool(
-            self._table_metadata.properties,
-            TableProperties.MANIFEST_MERGE_ENABLED,
-            TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
-        )
+    def on_write_start(self) -> None:
+        """Initialize table for writing."""
+        catalog = self._get_catalog()
+        self._table = catalog.load_table(self.table_identifier)
 
-    def write(
-        self, blocks: Iterable[Block], ctx: TaskContext
-    ) -> WriteResult[List["DataFile"]]:
-        from pyiceberg.io.pyarrow import (
-            _check_pyarrow_schema_compatible,
-            _dataframe_to_data_files,
-        )
-        from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE
-        from pyiceberg.utils.config import Config
+    def _collect_tables_from_blocks(self, blocks: Iterable[Block]) -> List["pa.Table"]:
+        """Collect PyArrow tables from blocks."""
+        collected_tables = []
 
-        data_files_list: WriteResult[List["DataFile"]] = []
         for block in blocks:
             pa_table = BlockAccessor.for_block(block).to_arrow()
 
-            downcast_ns_timestamp_to_us = (
-                Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-            )
-            _check_pyarrow_schema_compatible(
-                self._table_metadata.schema(),
-                provided_schema=pa_table.schema,
-                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
-            )
+            if pa_table.num_rows > 0:
+                collected_tables.append(pa_table)
 
-            if pa_table.shape[0] <= 0:
-                continue
+        return collected_tables
 
-            task_uuid = uuid.uuid4()
-            data_files = _dataframe_to_data_files(
-                self._table_metadata, pa_table, self._io, task_uuid
-            )
-            data_files_list.extend(data_files)
+    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> List["pa.Table"]:
+        """Collect blocks as PyArrow tables for all write modes."""
+        return self._collect_tables_from_blocks(blocks)
+
+    def _collect_and_concat_tables(
+        self, write_result: WriteResult[List["pa.Table"]]
+    ) -> Optional["pa.Table"]:
+        """Collect and concatenate all PyArrow tables from write results."""
+        import pyarrow as pa
+
+        all_tables = []
+        for tables_batch in write_result.write_returns:
+            all_tables.extend(tables_batch)
 
-        return data_files_list
+        if not all_tables:
+            logger.warning("No data to write")
+            return None
 
-    def on_write_complete(self, write_result: WriteResult[List["DataFile"]]):
-        update_snapshot = self._txn.update_snapshot(
-            snapshot_properties=self._snapshot_properties
+        return pa.concat_tables(all_tables)
+
+    def _complete_append(self, combined_table: "pa.Table") -> None:
+        """Complete APPEND mode write using PyIceberg's append API."""
+        self._table.append(
+            df=combined_table,
+            snapshot_properties=self._snapshot_properties,
         )
-        append_method = (
-            update_snapshot.merge_append
-            if self._manifest_merge_enabled
-            else update_snapshot.fast_append
+        logger.info(
+            f"Appended {combined_table.num_rows} rows to {self.table_identifier}"
        )
 
-        with append_method() as append_files:
-            append_files.commit_uuid = self._uuid
-            for data_files in write_result.write_returns:
-                for data_file in data_files:
-                    append_files.append_data_file(data_file)
+    def _complete_upsert(self, combined_table: "pa.Table") -> None:
+        """Complete UPSERT mode write using PyIceberg's upsert API."""
+        self._table.upsert(df=combined_table, **self._upsert_kwargs)
 
-        self._txn.commit_transaction()
+        join_cols = self._upsert_kwargs.get("join_cols")
+        if join_cols:
+            logger.info(
+                f"Upserted {combined_table.num_rows} rows to {self.table_identifier} "
+                f"using join columns: {join_cols}"
+            )
+        else:
+            logger.info(
+                f"Upserted {combined_table.num_rows} rows to {self.table_identifier} "
+                f"using table-defined identifier-field-ids"
+            )
+
+    def _complete_overwrite(self, combined_table: "pa.Table") -> None:
+        """Complete OVERWRITE mode write using PyIceberg's overwrite API."""
+        # Warn if user passed overwrite_filter via overwrite_kwargs
+        if "overwrite_filter" in self._overwrite_kwargs:
+            self._overwrite_kwargs.pop("overwrite_filter")
+            logger.warning(
+                "Use Ray Data's Expressions for overwrite filter instead of passing "
+                "it via PyIceberg's overwrite_filter parameter"
+            )
+
+        if self._overwrite_filter:
+            # Partial overwrite with filter
+            from ray.data._internal.datasource.iceberg_datasource import (
+                _IcebergExpressionVisitor,
+            )
+
+            iceberg_filter = _IcebergExpressionVisitor().visit(self._overwrite_filter)
+            self._table.overwrite(
+                df=combined_table,
+                overwrite_filter=iceberg_filter,
+                snapshot_properties=self._snapshot_properties,
+                **self._overwrite_kwargs,
+            )
+            logger.info(
+                f"Overwrote {combined_table.num_rows} rows in {self.table_identifier} "
+                f"matching filter: {self._overwrite_filter}"
+            )
+        else:
+            # Full table overwrite
+            self._table.overwrite(
+                df=combined_table,
+                snapshot_properties=self._snapshot_properties,
+                **self._overwrite_kwargs,
+            )
+            logger.info(
+                f"Overwrote entire table {self.table_identifier} "
+                f"with {combined_table.num_rows} rows"
+            )
+
+    def on_write_complete(self, write_result: WriteResult[List["pa.Table"]]) -> None:
+        """Complete the write operation based on the configured mode."""
+        # Collect and concatenate all PyArrow tables
+        combined_table = self._collect_and_concat_tables(write_result)
+        if combined_table is None:
+            return
+
+        # Apply schema evolution for all modes (PyIceberg doesn't handle this automatically)
+        self._update_schema(combined_table.schema)
+
+        # Execute the appropriate write operation
+        if self._mode == SaveMode.APPEND:
+            self._complete_append(combined_table)
+        elif self._mode == SaveMode.UPSERT:
+            self._complete_upsert(combined_table)
+        elif self._mode == SaveMode.OVERWRITE:
+            self._complete_overwrite(combined_table)
+        else:
+            raise ValueError(
+                f"Unsupported write mode: {self._mode}. "
+                f"Supported modes are: APPEND, UPSERT, OVERWRITE"
            )
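
For context, a minimal standalone sketch of the PyIceberg calls the reworked sink delegates to in `on_write_complete` (schema union-by-name, then append / upsert / overwrite). The catalog configuration, table name, and filter below are illustrative assumptions, not part of this diff.

```python
import pyarrow as pa
from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import EqualTo

# Illustrative catalog/table; assumes "default.example_table" already exists.
catalog = load_catalog("default", type="sql", uri="sqlite:///catalog.db")
table = catalog.load_table("default.example_table")

batch = pa.table({"id": [1, 2], "value": ["a", "b"]})

# Schema evolution: merge any new columns from the incoming data by name.
with table.update_schema() as update:
    update.union_by_name(batch.schema)

# APPEND: add rows without deduplication.
table.append(df=batch, snapshot_properties={"commit_source": "ray-data"})

# UPSERT: match on join columns, update matches, insert the rest.
table.upsert(df=batch, join_cols=["id"])

# OVERWRITE: replace only rows matching the filter (or the whole table
# when no filter is given).
table.overwrite(df=batch, overwrite_filter=EqualTo("id", 1))
```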

python/ray/data/_internal/savemode.py

Lines changed: 14 additions & 0 deletions
@@ -5,7 +5,21 @@
 
 @PublicAPI(stability="alpha")
 class SaveMode(str, Enum):
+    """Enum of possible modes for saving/writing data."""
+
     APPEND = "append"
+    """Add new data without modifying existing data."""
+
     OVERWRITE = "overwrite"
+    """Replace all existing data with new data."""
+
     IGNORE = "ignore"
+    """Don't write if data already exists."""
+
     ERROR = "error"
+    """Raise an error if data already exists."""
+
+    UPSERT = "upsert"
+    """Update existing rows that match on key fields, or insert new rows.
+    Requires identifier/key fields to be specified.
+    """
