Use self.table_metadata in transaction (#985)
HonahX authored and sungwy committed Aug 9, 2024
1 parent 034e892 commit feaf7e4
Showing 3 changed files with 22 additions and 20 deletions.
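Background for the diff below: a Transaction keeps its own working copy of the table metadata, exposed as self.table_metadata and rebuilt as each update is staged, while self._table.metadata stays frozen at whatever the table looked like when the transaction began. This commit points the write paths (append, overwrite, delete, add_files) and upgrade_table_version at the working copy, so an operation later in the same transaction sees updates staged earlier. A minimal sketch of the usage this enables, mirroring the updated tests (the catalog and table names are hypothetical):

import pyarrow as pa
from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")          # hypothetical catalog name
tbl = catalog.load_table("default.test")   # hypothetical table identifier

# A batch that carries a column ("bar") the table does not have yet.
pa_table_with_column = pa.table({
    "foo": ["a", "b", "c"],
    "bar": [1, 2, 3],
})

with tbl.transaction() as txn:
    # Stage a schema update inside the transaction.
    with txn.update_schema() as schema_txn:
        schema_txn.union_by_name(pa_table_with_column.schema)

    # append() now validates df.schema against txn.table_metadata, which
    # already includes the staged column; before this commit it read the
    # stale self._table.metadata from the start of the transaction.
    txn.append(pa_table_with_column)
    txn.delete("foo = 'a'")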
30 changes: 16 additions & 14 deletions pyiceberg/table/__init__.py
@@ -331,10 +331,10 @@ def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
         if format_version not in {1, 2}:
             raise ValueError(f"Unsupported table format version: {format_version}")

-        if format_version < self._table.metadata.format_version:
-            raise ValueError(f"Cannot downgrade v{self._table.metadata.format_version} table to v{format_version}")
+        if format_version < self.table_metadata.format_version:
+            raise ValueError(f"Cannot downgrade v{self.table_metadata.format_version} table to v{format_version}")

-        if format_version > self._table.metadata.format_version:
+        if format_version > self.table_metadata.format_version:
             return self._apply((UpgradeFormatVersionUpdate(format_version=format_version),))

         return self
@@ -452,7 +452,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive
             self,
             allow_incompatible_changes=allow_incompatible_changes,
             case_sensitive=case_sensitive,
-            name_mapping=self._table.name_mapping(),
+            name_mapping=self.table_metadata.name_mapping(),
         )

     def update_snapshot(self, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> UpdateSnapshot:
@@ -489,7 +489,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
-            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+            self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )

         manifest_merge_enabled = PropertyUtil.property_as_bool(
@@ -504,7 +504,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
                 data_files = _dataframe_to_data_files(
-                    table_metadata=self._table.metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
+                    table_metadata=self.table_metadata, write_uuid=append_files.commit_uuid, df=df, io=self._table.io
                 )
                 for data_file in data_files:
                     append_files.append_data_file(data_file)
@@ -548,7 +548,7 @@ def overwrite(
         )
         downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
         _check_pyarrow_schema_compatible(
-            self._table.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
+            self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
         )

         self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties)
@@ -557,7 +557,7 @@ def overwrite(
             # skip writing data files if the dataframe is empty
             if df.shape[0] > 0:
                 data_files = _dataframe_to_data_files(
-                    table_metadata=self._table.metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
+                    table_metadata=self.table_metadata, write_uuid=update_snapshot.commit_uuid, df=df, io=self._table.io
                 )
                 for data_file in data_files:
                     update_snapshot.append_data_file(data_file)
@@ -595,7 +595,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti

         # Check if there are any files that require an actual rewrite of a data file
         if delete_snapshot.rewrites_needed is True:
-            bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
+            bound_delete_filter = bind(self.table_metadata.schema(), delete_filter, case_sensitive=True)
             preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

             files = self._scan(row_filter=delete_filter).plan_files()
@@ -614,7 +614,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
             for original_file in files:
                 df = project_table(
                     tasks=[original_file],
-                    table_metadata=self._table.metadata,
+                    table_metadata=self.table_metadata,
                     io=self._table.io,
                     row_filter=AlwaysTrue(),
                     projected_schema=self.table_metadata.schema(),
@@ -629,7 +629,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
                     _dataframe_to_data_files(
                         io=self._table.io,
                         df=filtered_df,
-                        table_metadata=self._table.metadata,
+                        table_metadata=self.table_metadata,
                         write_uuid=commit_uuid,
                         counter=counter,
                     )
@@ -658,11 +658,13 @@ def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] =
         Raises:
             FileNotFoundError: If the file does not exist.
         """
-        if self._table.name_mapping() is None:
-            self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
+        if self.table_metadata.name_mapping() is None:
+            self.set_properties(**{
+                TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()
+            })
         with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
             data_files = _parquet_files_to_data_files(
-                table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
+                table_metadata=self.table_metadata, file_paths=file_paths, io=self._table.io
             )
             for data_file in data_files:
                 update_snapshot.append_data_file(data_file)
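Why the distinction matters: each staged update yields a fresh metadata object on the transaction, while the underlying Table object is untouched until commit. A toy model of that staleness, using simplified invented classes rather than pyiceberg's real internals:

from dataclasses import dataclass, replace

# Toy stand-ins for illustration only; pyiceberg's actual classes differ.
@dataclass(frozen=True)
class Metadata:
    format_version: int

class Table:
    def __init__(self, metadata: Metadata) -> None:
        self.metadata = metadata  # frozen at transaction start

class Transaction:
    def __init__(self, table: Table) -> None:
        self._table = table
        self.table_metadata = table.metadata  # working copy, advanced per update

    def upgrade_table_version(self, version: int) -> "Transaction":
        # Reading self.table_metadata means a second upgrade in the same
        # transaction sees the first one; self._table.metadata would not.
        if version > self.table_metadata.format_version:
            self.table_metadata = replace(self.table_metadata, format_version=version)
        return self

txn = Transaction(Table(Metadata(format_version=1)))
txn.upgrade_table_version(2)
assert txn.table_metadata.format_version == 2   # working copy advanced
assert txn._table.metadata.format_version == 1  # table unchanged until commit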
6 changes: 3 additions & 3 deletions tests/catalog/test_sql.py
@@ -1421,9 +1421,9 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
         with txn.update_schema() as schema_txn:
             schema_txn.union_by_name(pa_table_with_column.schema)

-        with txn.update_snapshot().fast_append() as snapshot_update:
-            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io):
-                snapshot_update.append_data_file(data_file)
+        txn.append(pa_table_with_column)
+        txn.overwrite(pa_table_with_column)
+        txn.delete("foo = 'a'")


 @pytest.mark.parametrize(
6 changes: 3 additions & 3 deletions tests/integration/test_writes/test_writes.py
@@ -718,9 +718,9 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None
         with txn.update_schema() as schema_txn:
             schema_txn.union_by_name(pa_table_with_column.schema)

-        with txn.update_snapshot().fast_append() as snapshot_update:
-            for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table_with_column, io=tbl.io):
-                snapshot_update.append_data_file(data_file)
+        txn.append(pa_table_with_column)
+        txn.overwrite(pa_table_with_column)
+        txn.delete("foo = 'a'")


 @pytest.mark.integration
