[FEAT] Add time travel to read_deltalake (#3022)
kevinzwang authored Oct 8, 2024
1 parent 3f37a69 commit f995792
Showing 3 changed files with 38 additions and 4 deletions.
8 changes: 7 additions & 1 deletion daft/delta_lake/delta_lake_scan.py
@@ -22,12 +22,15 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from datetime import datetime

logger = logging.getLogger(__name__)


class DeltaLakeScanOperator(ScanOperator):
def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
def __init__(
self, table_uri: str, storage_config: StorageConfig, version: int | str | datetime | None = None
) -> None:
super().__init__()

# Unfortunately delta-rs doesn't do very good inference of credentials for S3. Thus the current Daft behavior of passing
@@ -67,6 +70,9 @@ def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
table_uri, storage_options=io_config_to_storage_options(deltalake_sdk_io_config, table_uri)
)

if version is not None:
self._table.load_as_version(version)

self._storage_config = storage_config
self._schema = Schema.from_pyarrow_schema(self._table.schema().to_pyarrow())
partition_columns = set(self._table.metadata().partition_columns)
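The scan operator forwards the new `version` argument straight to delta-rs. Below is a minimal standalone sketch, outside of Daft, of the `DeltaTable.load_as_version` call it now makes; the table path and timestamp are hypothetical and assume the `deltalake` package is installed with an existing table at that location.

```python
from datetime import datetime, timezone

from deltalake import DeltaTable

table_uri = "/tmp/some_table"  # hypothetical path to an existing Delta Lake table

dt = DeltaTable(table_uri)

# An integer pins an exact table version...
dt.load_as_version(0)

# ...while a datetime (or an RFC 3339 / ISO 8601 string) loads the version
# that was current at that point in time.
dt.load_as_version(datetime(2024, 10, 8, tzinfo=timezone.utc))
```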
12 changes: 9 additions & 3 deletions daft/io/_deltalake.py
@@ -11,12 +11,15 @@
from daft.logical.builder import LogicalPlanBuilder

if TYPE_CHECKING:
from datetime import datetime

from daft.unity_catalog import UnityCatalogTable


@PublicAPI
def read_deltalake(
table: Union[str, DataCatalogTable, "UnityCatalogTable"],
version: Optional[Union[int, str, "datetime"]] = None,
io_config: Optional["IOConfig"] = None,
_multithreaded_io: Optional[bool] = None,
) -> DataFrame:
@@ -37,8 +40,11 @@ def read_deltalake(
Args:
table: Either a URI for the Delta Lake table or a :class:`~daft.io.catalog.DataCatalogTable` instance
referencing a table in a data catalog, such as AWS Glue Data Catalog or Databricks Unity Catalog.
io_config: A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
_multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
version (optional): If an int is passed, read the table at that version number. If a string or datetime is passed,
read the table version that was current at that timestamp. Strings must be in RFC 3339 / ISO 8601 date and time format.
Datetimes are assumed to be in UTC unless a timezone is specified. By default, read the latest version of the table.
io_config (optional): A custom :class:`~daft.daft.IOConfig` to use when accessing Delta Lake object storage data. Defaults to None.
_multithreaded_io (optional): Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing
the amount of system resources (number of connections and thread contention) when running in the Ray runner.
Defaults to None, which will let Daft decide based on the runner it is currently using.
@@ -69,7 +75,7 @@
raise ValueError(
f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}"
)
delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config)
delta_lake_operator = DeltaLakeScanOperator(table_uri, storage_config=storage_config, version=version)

handle = ScanOperatorHandle.from_python_scan_operator(delta_lake_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
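From the user's side, the new argument makes time travel a one-liner. A short usage sketch of the public API added by this commit, assuming a table that already has multiple versions; the path is hypothetical.

```python
import daft

path = "/tmp/some_table"  # hypothetical path to an existing Delta Lake table

# Default: read the latest version of the table.
df_latest = daft.read_deltalake(path)

# Time travel by version number.
df_v0 = daft.read_deltalake(path, version=0)

# Time travel by timestamp: an RFC 3339 / ISO 8601 string (or a datetime object)
# selects the version that was current at that time, assumed UTC unless a
# timezone is given.
df_at_ts = daft.read_deltalake(path, version="2024-10-01T00:00:00Z")
```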
22 changes: 22 additions & 0 deletions tests/io/delta_lake/test_table_read.py
@@ -94,3 +94,25 @@ def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table):
df = df.limit(2)
df.collect()
assert len(df) == 2, "Length of non-materialized data when read through deltalake should be correct"


def test_deltalake_read_versioned(tmp_path, base_table):
deltalake = pytest.importorskip("deltalake")
path = tmp_path / "some_table"
deltalake.write_deltalake(path, base_table)

updated_columns = base_table.columns + [pa.array(["x", "y", "z"])]
updated_column_names = base_table.column_names + ["new_column"]
updated_table = pa.Table.from_arrays(updated_columns, names=updated_column_names)
deltalake.write_deltalake(path, updated_table, mode="overwrite", schema_mode="overwrite")

for version in [None, 1]:
df = daft.read_deltalake(str(path), version=version)
expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path).schema().to_pyarrow())
assert df.schema() == expected_schema
assert_pyarrow_tables_equal(df.to_arrow(), updated_table)

df = daft.read_deltalake(str(path), version=0)
expected_schema = Schema.from_pyarrow_schema(deltalake.DeltaTable(path, version=0).schema().to_pyarrow())
assert df.schema() == expected_schema
assert_pyarrow_tables_equal(df.to_arrow(), base_table)
