From 72d4b24f76bed465ef018a9a6a090cd1128b32d5 Mon Sep 17 00:00:00 2001
From: Kevin Wang
Date: Tue, 8 Oct 2024 14:12:31 -0700
Subject: [PATCH] [FEAT] Add time travel to read_deltalake

---
 daft/delta_lake/delta_lake_scan.py     |  8 +++++++-
 daft/io/_deltalake.py                  | 12 +++++++++---
 tests/io/delta_lake/test_table_read.py | 22 ++++++++++++++++++++++
 3 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py
index eb6973f24d..56bc60bf55 100644
--- a/daft/delta_lake/delta_lake_scan.py
+++ b/daft/delta_lake/delta_lake_scan.py
@@ -22,12 +22,15 @@
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
+    from datetime import datetime
 
 logger = logging.getLogger(__name__)
 
 
 class DeltaLakeScanOperator(ScanOperator):
-    def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
+    def __init__(
+        self, table_uri: str, storage_config: StorageConfig, version: int | str | datetime | None = None
+    ) -> None:
         super().__init__()
 
         # Unfortunately delta-rs doesn't do very good inference of credentials for S3. Thus the current Daft behavior of passing
@@ -67,6 +70,9 @@ def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
             table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri)
         )
 
+        if version is not None:
+            self._table.load_as_version(version)
+
         self._storage_config = storage_config
         self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow())
         partition_columns = set(self._table.metadata().partition_columns)
diff --git a/daft/io/_deltalake.py b/daft/io/_deltalake.py
index c4530bcd98..7165c1a341 100644
--- a/daft/io/_deltalake.py
+++ b/daft/io/_deltalake.py
@@ -11,12 +11,15 @@
 from daft.logical.builder import LogicalPlanBuilder
 
 if TYPE_CHECKING:
+    from datetime import datetime
+
     from daft.unity_catalog import UnityCatalogTable
 
 
 @PublicAPI
 def read_deltalake(
     table: Union[str, DataCatalogTable, "UnityCatalogTable"],
+    version: Optional[Union[int, str, "datetime"]] = None,
     io_config: Optional["IOConfig"] = None,
     _multithreaded_io: Optional[bool] = None,
 ) -> DataFrame:
@@ -37,8 +40,11 @@ def read_deltalake(
     Args:
         table: Either a URI for the Delta Lake table or a :class:`~daft.io.catalog.DataCatalogTable` instance
             referencing a table in a data catalog, such as AWS Glue Data Catalog or Databricks Unity Catalog.
-        io_config: A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
-        _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
+        version (optional): If an int is passed, read the table at that commit version. If a string or datetime is passed,
+            read the table as of that timestamp. Strings must be in RFC 3339 / ISO 8601 date and time format.
+            Datetimes are assumed to be in UTC unless a timezone is given. Defaults to None, which reads the latest version of the table.
+        io_config (optional): A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
+        _multithreaded_io (optional): Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
             the amount of system resources (number of connections and thread contention) when running in the Ray runner.
             Defaults to None, which will let Daft decide based on the runner it is currently using.
 
@@ -69,7 +75,7 @@ def read_deltalake(
         raise ValueError(
             f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
         )
-    delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)
+    delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config, version=version)
 
     handle = ScanOperatorHandle.from_python_scan_operator(delta_lake_operator)
     builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
diff --git a/tests/io/delta_lake/test_table_read.py b/tests/io/delta_lake/test_table_read.py
index 9cb5881a72..273006659f 100644
--- a/tests/io/delta_lake/test_table_read.py
+++ b/tests/io/delta_lake/test_table_read.py
@@ -94,3 +94,25 @@ def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table):
     df = df.limit(2)
     df.collect()
     assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct"
+
+
+def test_deltalake_read_versioned(tmp_path, base_table):
+    deltalake = pytest.importorskip("deltalake")
+    path = tmp_path / "some_table"
+    deltalake.write_deltalake(path, base_table)
+
+    updated_columns = base_table.columns + [pa.array(["x", "y", "z"])]
+    updated_column_names = base_table.column_names + ["new_column"]
+    updated_table = pa.Table.from_arrays(updated_columns, names=updated_column_names)
+    deltalake.write_deltalake(path, updated_table, mode="overwrite", schema_mode="overwrite")
+
+    for version in [None, 1]:
+        df = daft.read_deltalake(str(path), version=version)
+        expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path).schema().to_pyarrow())
+        assert df.schema() == expected_schema
+        assert_pyarrow_tables_equal(df.to_arrow(), updated_table)
+
+    df = daft.read_deltalake(str(path), version=0)
+    expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path, version=0).schema().to_pyarrow())
+    assert df.schema() == expected_schema
+    assert_pyarrow_tables_equal(df.to_arrow(), base_table)
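A minimal usage sketch of the new `version` argument, assuming a hypothetical local table path: an int selects a commit version, while a string or datetime selects the table state as of that timestamp, per the docstring added above.

from datetime import datetime, timezone

import daft

# Hypothetical table URI, used here for illustration only.
TABLE_URI = "/tmp/events_table"

# Default behavior is unchanged: read the latest version.
df_latest = daft.read_deltalake(TABLE_URI)

# Time travel by commit version number.
df_v0 = daft.read_deltalake(TABLE_URI, version=0)

# Time travel by timestamp: an RFC 3339 / ISO 8601 string, or a datetime
# (assumed to be UTC when no timezone is given).
df_ts = daft.read_deltalake(TABLE_URI, version="2024-10-01T00:00:00Z")
df_dt = daft.read_deltalake(TABLE_URI, version=datetime(2024, 10, 1, tzinfo=timezone.utc))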