 Module to write a Ray Dataset into an Iceberg table using the Ray Datasink API.
 """
 import logging
-import uuid
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional
 
-from packaging import version
-
 from ray.data._internal.execution.interfaces import TaskContext
+from ray.data._internal.savemode import SaveMode
 from ray.data.block import Block, BlockAccessor
 from ray.data.datasource.datasink import Datasink, WriteResult
 from ray.util.annotations import DeveloperAPI
 
 if TYPE_CHECKING:
+    import pyarrow as pa
     from pyiceberg.catalog import Catalog
-    from pyiceberg.manifest import DataFile
+    from pyiceberg.table import Table
+
+    from ray.data.expressions import Expr
 
 
 logger = logging.getLogger(__name__)
 
 
 @DeveloperAPI
-class IcebergDatasink(Datasink[List["DataFile"]]):
+class IcebergDatasink(Datasink[List["pa.Table"]]):
     """
     Iceberg datasink to write a Ray Dataset into an existing Iceberg table. This module
     relies heavily on PyIceberg to write to the Iceberg table. All the routines in this class override
@@ -34,137 +35,223 @@ def __init__(
         table_identifier: str,
         catalog_kwargs: Optional[Dict[str, Any]] = None,
         snapshot_properties: Optional[Dict[str, str]] = None,
+        mode: SaveMode = SaveMode.APPEND,
+        overwrite_filter: Optional["Expr"] = None,
+        upsert_kwargs: Optional[Dict[str, Any]] = None,
+        overwrite_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
         Initialize the IcebergDatasink.
 
         Args:
-            table_identifier: The identifier of the table to read e.g. `default.taxi_dataset`
+            table_identifier: The identifier of the table to write to, such as
+                `default.taxi_dataset`
             catalog_kwargs: Optional arguments to use when setting up the Iceberg
                 catalog
-            snapshot_properties: custom properties write to snapshot when committing
-                to an iceberg table, e.g. {"commit_time": "2021-01-01T00:00:00Z"}
+            snapshot_properties: Custom properties to write to the snapshot summary,
+                such as commit metadata
+            mode: The write mode, one of APPEND, UPSERT, or OVERWRITE. Defaults to APPEND.
+                - APPEND: Add new data without checking for duplicates
+                - UPSERT: Update existing rows or insert new ones based on a join condition
+                - OVERWRITE: Replace table data (all data or a filtered subset)
+            overwrite_filter: Optional filter for OVERWRITE mode to perform partial overwrites.
+                Must be a Ray Data expression from `ray.data.expressions`. Only rows matching
+                this filter are replaced. If None with OVERWRITE mode, all table data is replaced.
+            upsert_kwargs: Optional arguments passed through to PyIceberg's table.upsert()
+                method. Supported parameters include join_cols (List[str]),
+                when_matched_update_all (bool), when_not_matched_insert_all (bool),
+                case_sensitive (bool), and branch (str). See the PyIceberg documentation
+                for details.
+            overwrite_kwargs: Optional arguments passed through to PyIceberg's table.overwrite()
+                method. Supported parameters include case_sensitive (bool) and branch (str).
+                See the PyIceberg documentation for details.
+
+        Note:
+            Schema evolution is enabled automatically: new columns in the incoming data
+            are added to the table schema.
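+
+        Example:
+            A minimal usage sketch; the catalog settings, table name, and the import
+            path of this datasink are illustrative assumptions:
+
+            .. code-block:: python
+
+                import ray
+                from ray.data._internal.datasource.iceberg_datasink import IcebergDatasink
+                from ray.data._internal.savemode import SaveMode
+
+                ds = ray.data.from_items([{"id": 1, "value": "a"}])
+                sink = IcebergDatasink(
+                    table_identifier="default.taxi_dataset",
+                    catalog_kwargs={
+                        "name": "default",
+                        "type": "sql",
+                        "uri": "sqlite:///catalog.db",
+                    },
+                    mode=SaveMode.UPSERT,
+                    upsert_kwargs={"join_cols": ["id"]},
+                )
+                ds.write_datasink(sink)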
         """
 
-        from pyiceberg.io import FileIO
-        from pyiceberg.table import Transaction
-        from pyiceberg.table.metadata import TableMetadata
-
         self.table_identifier = table_identifier
-        self._catalog_kwargs = catalog_kwargs if catalog_kwargs is not None else {}
-        self._snapshot_properties = (
-            snapshot_properties if snapshot_properties is not None else {}
-        )
+        self._catalog_kwargs = (catalog_kwargs or {}).copy()
+        self._snapshot_properties = snapshot_properties or {}
+        self._mode = mode
+        self._overwrite_filter = overwrite_filter
+        self._upsert_kwargs = (upsert_kwargs or {}).copy()
+        self._overwrite_kwargs = (overwrite_kwargs or {}).copy()
+
+        # Validate that mode-specific kwargs are only set for the relevant mode
+        if self._upsert_kwargs and self._mode != SaveMode.UPSERT:
+            raise ValueError(
+                "upsert_kwargs can only be specified when mode is SaveMode.UPSERT, "
+                f"but mode is {self._mode}"
+            )
+        if self._overwrite_kwargs and self._mode != SaveMode.OVERWRITE:
+            raise ValueError(
+                "overwrite_kwargs can only be specified when mode is SaveMode.OVERWRITE, "
+                f"but mode is {self._mode}"
+            )
 
         if "name" in self._catalog_kwargs:
             self._catalog_name = self._catalog_kwargs.pop("name")
         else:
             self._catalog_name = "default"
 
-        self._uuid: str = None
-        self._io: FileIO = None
-        self._txn: Transaction = None
-        self._table_metadata: TableMetadata = None
+        self._table: "Table" = None
 
-    # Since iceberg transaction is not pickle-able, because of the table and catalog properties
-    # we need to exclude the transaction object during serialization and deserialization during pickle
+    # Since the Iceberg table is not picklable, we need to exclude it during serialization
     def __getstate__(self) -> dict:
-        """Exclude `_txn` during pickling."""
+        """Exclude `_table` during pickling."""
         state = self.__dict__.copy()
-        del state["_txn"]
+        state.pop("_table", None)
         return state
 
     def __setstate__(self, state: dict) -> None:
         self.__dict__.update(state)
-        self._txn = None
+        self._table = None
 
     def _get_catalog(self) -> "Catalog":
         from pyiceberg import catalog
 
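+        # load_catalog resolves the named catalog from the given properties, falling
+        # back to the environment / .pyiceberg.yaml config; e.g. catalog_kwargs like
+        # {"type": "sql", "uri": "sqlite:///catalog.db"} would configure a SQL catalog
+        # (illustrative values).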
         return catalog.load_catalog(self._catalog_name, **self._catalog_kwargs)
 
-    def on_write_start(self) -> None:
-        """Prepare for the transaction"""
-        import pyiceberg
-        from pyiceberg.table import TableProperties
+    def _update_schema(self, incoming_schema: "pa.Schema") -> None:
+        """
+        Update the table schema to accommodate incoming data using union-by-name semantics.
 
-        if version.parse(pyiceberg.__version__) >= version.parse("0.9.0"):
-            from pyiceberg.utils.properties import property_as_bool
-        else:
-            from pyiceberg.table import PropertyUtil
+        This automatically handles:
+        - Adding new columns from the incoming schema
+        - Type promotion (e.g., int32 -> int64) where compatible
+        - Preserving existing columns not in the incoming schema
 
-            property_as_bool = PropertyUtil.property_as_bool
+        Args:
+            incoming_schema: The PyArrow schema from the incoming data
+        """
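+        # For example, if the table has columns (id: long, name: string) and the
+        # incoming data adds an int32 `age` column, union_by_name adds `age` to the
+        # table schema as an optional column (hypothetical columns for illustration).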
+        # Use PyIceberg's update_schema API
+        with self._table.update_schema() as update:
+            update.union_by_name(incoming_schema)
 
+        # Reload the table after schema evolution so later writes see the new schema
         catalog = self._get_catalog()
-        table = catalog.load_table(self.table_identifier)
-        self._txn = table.transaction()
-        self._io = self._txn._table.io
-        self._table_metadata = self._txn.table_metadata
-        self._uuid = uuid.uuid4()
-
-        if unsupported_partitions := [
-            field
-            for field in self._table_metadata.spec().fields
-            if not field.transform.supports_pyarrow_transform
-        ]:
-            raise ValueError(
-                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
-            )
+        self._table = catalog.load_table(self.table_identifier)
 
-        self._manifest_merge_enabled = property_as_bool(
-            self._table_metadata.properties,
-            TableProperties.MANIFEST_MERGE_ENABLED,
-            TableProperties.MANIFEST_MERGE_ENABLED_DEFAULT,
-        )
+    def on_write_start(self) -> None:
+        """Initialize the table for writing."""
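+        # Called once on the driver before any write tasks run. The table handle
+        # loaded here is excluded when the datasink is pickled to workers (see
+        # __getstate__) and is used again on the driver in on_write_complete.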
+        catalog = self._get_catalog()
+        self._table = catalog.load_table(self.table_identifier)
 
-    def write(
-        self, blocks: Iterable[Block], ctx: TaskContext
-    ) -> WriteResult[List["DataFile"]]:
-        from pyiceberg.io.pyarrow import (
-            _check_pyarrow_schema_compatible,
-            _dataframe_to_data_files,
-        )
-        from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE
-        from pyiceberg.utils.config import Config
+    def _collect_tables_from_blocks(self, blocks: Iterable[Block]) -> List["pa.Table"]:
+        """Collect PyArrow tables from blocks."""
+        collected_tables = []
 
-        data_files_list: WriteResult[List["DataFile"]] = []
         for block in blocks:
             pa_table = BlockAccessor.for_block(block).to_arrow()
 
-            downcast_ns_timestamp_to_us = (
-                Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
-            )
-            _check_pyarrow_schema_compatible(
-                self._table_metadata.schema(),
-                provided_schema=pa_table.schema,
-                downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
-            )
+            if pa_table.num_rows > 0:
+                collected_tables.append(pa_table)
 
-            if pa_table.shape[0] <= 0:
-                continue
+        return collected_tables
 
-            task_uuid = uuid.uuid4()
-            data_files = _dataframe_to_data_files(
-                self._table_metadata, pa_table, self._io, task_uuid
-            )
-            data_files_list.extend(data_files)
+    def write(self, blocks: Iterable[Block], ctx: TaskContext) -> List["pa.Table"]:
+        """Collect blocks as PyArrow tables for all write modes."""
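+        # Runs inside each Ray write task; the returned tables are shipped back to
+        # the driver, which commits them in on_write_complete.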
+        return self._collect_tables_from_blocks(blocks)
+
+    def _collect_and_concat_tables(
+        self, write_result: WriteResult[List["pa.Table"]]
+    ) -> Optional["pa.Table"]:
+        """Collect and concatenate all PyArrow tables from write results."""
+        import pyarrow as pa
+
+        all_tables = []
+        for tables_batch in write_result.write_returns:
+            all_tables.extend(tables_batch)
 
-        return data_files_list
+        if not all_tables:
+            logger.warning("No data to write")
+            return None
 
-    def on_write_complete(self, write_result: WriteResult[List["DataFile"]]):
-        update_snapshot = self._txn.update_snapshot(
-            snapshot_properties=self._snapshot_properties
+        return pa.concat_tables(all_tables)
+
+    def _complete_append(self, combined_table: "pa.Table") -> None:
+        """Complete APPEND mode write using PyIceberg's append API."""
+        self._table.append(
+            df=combined_table,
+            snapshot_properties=self._snapshot_properties,
         )
-        append_method = (
-            update_snapshot.merge_append
-            if self._manifest_merge_enabled
-            else update_snapshot.fast_append
+        logger.info(
+            f"Appended {combined_table.num_rows} rows to {self.table_identifier}"
        )
 
-        with append_method() as append_files:
-            append_files.commit_uuid = self._uuid
-            for data_files in write_result.write_returns:
-                for data_file in data_files:
-                    append_files.append_data_file(data_file)
+    def _complete_upsert(self, combined_table: "pa.Table") -> None:
+        """Complete UPSERT mode write using PyIceberg's upsert API."""
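+        # e.g. upsert_kwargs={"join_cols": ["id"]} matches rows on `id` (hypothetical
+        # column); without join_cols, PyIceberg falls back to the table's
+        # identifier fields, as logged below.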
+        self._table.upsert(df=combined_table, **self._upsert_kwargs)
 
-        self._txn.commit_transaction()
+        join_cols = self._upsert_kwargs.get("join_cols")
+        if join_cols:
+            logger.info(
+                f"Upserted {combined_table.num_rows} rows to {self.table_identifier} "
+                f"using join columns: {join_cols}"
+            )
+        else:
+            logger.info(
+                f"Upserted {combined_table.num_rows} rows to {self.table_identifier} "
+                "using table-defined identifier-field-ids"
+            )
+
+    def _complete_overwrite(self, combined_table: "pa.Table") -> None:
+        """Complete OVERWRITE mode write using PyIceberg's overwrite API."""
+        # Warn if the user passed overwrite_filter via overwrite_kwargs
+        if "overwrite_filter" in self._overwrite_kwargs:
+            self._overwrite_kwargs.pop("overwrite_filter")
+            logger.warning(
+                "Use a Ray Data expression for the overwrite filter instead of "
+                "passing it via PyIceberg's overwrite_filter parameter"
+            )
+
+        if self._overwrite_filter:
+            # Partial overwrite with filter
+            from ray.data._internal.datasource.iceberg_datasource import (
+                _IcebergExpressionVisitor,
+            )
+
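+            # e.g. an expression like col("date") == lit("2024-01-01") from
+            # ray.data.expressions translates to the equivalent PyIceberg filter
+            # (hypothetical column and value).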
+            iceberg_filter = _IcebergExpressionVisitor().visit(self._overwrite_filter)
+            self._table.overwrite(
+                df=combined_table,
+                overwrite_filter=iceberg_filter,
+                snapshot_properties=self._snapshot_properties,
+                **self._overwrite_kwargs,
+            )
+            logger.info(
+                f"Wrote {combined_table.num_rows} rows to {self.table_identifier}, "
+                f"replacing rows matching filter: {self._overwrite_filter}"
+            )
+        else:
+            # Full table overwrite
+            self._table.overwrite(
+                df=combined_table,
+                snapshot_properties=self._snapshot_properties,
+                **self._overwrite_kwargs,
+            )
+            logger.info(
+                f"Overwrote entire table {self.table_identifier} "
+                f"with {combined_table.num_rows} rows"
+            )
+
+    def on_write_complete(self, write_result: WriteResult[List["pa.Table"]]) -> None:
+        """Complete the write operation based on the configured mode."""
+        # Collect and concatenate all PyArrow tables
+        combined_table = self._collect_and_concat_tables(write_result)
+        if combined_table is None:
+            return
+
+        # Apply schema evolution for all modes (PyIceberg doesn't handle this automatically)
+        self._update_schema(combined_table.schema)
+
+        # Execute the appropriate write operation
+        if self._mode == SaveMode.APPEND:
+            self._complete_append(combined_table)
+        elif self._mode == SaveMode.UPSERT:
+            self._complete_upsert(combined_table)
+        elif self._mode == SaveMode.OVERWRITE:
+            self._complete_overwrite(combined_table)
+        else:
+            raise ValueError(
+                f"Unsupported write mode: {self._mode}. "
+                "Supported modes are: APPEND, UPSERT, OVERWRITE"
+            )