Skip to content

Commit

Permalink
Minimal support for compute kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 27, 2025
1 parent bd25331 commit 4e9ad54
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 54 deletions.
37 changes: 12 additions & 25 deletions icechunk-python/python/icechunk/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ def write_lazy(
if not self.writer.sources:
return

import dask

# TODO: this is all pretty dask specific at the moment
chunkmanager_store_kwargs = chunkmanager_store_kwargs or {}
chunkmanager_store_kwargs["load_stored"] = False
Expand All @@ -181,21 +179,21 @@ def write_lazy(
compute=False, chunkmanager_store_kwargs=chunkmanager_store_kwargs
) # type: ignore[no-untyped-call]

if not compute:
# do not execute the write in this case
return

# Now we tree-reduce all changesets
merged_session = stateful_store_reduce(
stored_arrays,
prefix="ice-changeset",
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
compute=False,
compute=True,
**chunkmanager_store_kwargs,
)
if compute:
(new_session,) = dask.compute(merged_session)
self.store.session.merge(new_session)
else:
return merged_session
self.store.session.merge(merged_session)


def to_icechunk(
Expand All @@ -211,7 +209,7 @@ def to_icechunk(
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
compute: bool = True,
) -> Any:
) -> None:
"""
Write an Xarray object to a group of an icechunk store.
Expand Down Expand Up @@ -283,21 +281,7 @@ def to_icechunk(
Returns
-------
None if `compute` is `True`, `dask.Delayed` otherwise.
Examples
--------
When `compute` is False, this function returns a `dask.Delayed` object. Call `dask.compute`
on this object to execute the write and receive a Session object with the changes.
With distributed writes, you are responsible for merging these changes with your
`session` so that a commit can be properly executed.
>>> session = repo.writable_session("main")
>>> future = to_icechunk(dataset, session.store, ..., compute=False)
>>> (new_session,) = dask.compute(future) # execute the write
>>> session.merge(new_session) # merge in the (possibly) remote changes with the local session
>>> session.commit("wrote some data!")
None
Notes
-----
Expand All @@ -308,6 +292,9 @@ def to_icechunk(
should be written in a separate single call to ``to_zarr()``.
- Dimensions cannot be included in both ``region`` and
``append_dim`` at the same time.
Unlike Xarray's `to_zarr`, this function does *not* return a dask Delayed object
if `compute=False`.
"""

as_dataset = make_dataset(obj)
Expand All @@ -320,7 +307,7 @@ def to_icechunk(
# write in-memory arrays
writer.write_eager()
# eagerly write dask arrays
return writer.write_lazy(
writer.write_lazy(
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
)

Expand Down
33 changes: 4 additions & 29 deletions icechunk-python/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,25 @@

import dask
import distributed
import xarray as xr
from icechunk import Repository, local_filesystem_storage
from icechunk.xarray import to_icechunk
from tests.test_xarray import create_test_data, roundtrip
from xarray.testing import assert_identical


@pytest.mark.parametrize("compute", [False, True])
def test_distributed(compute: bool) -> None:
def test_distributed() -> None:
with distributed.Client(): # type: ignore [no-untyped-call]
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds, compute=compute) as actual:
with roundtrip(ds) as actual:
assert_identical(actual, ds)

# with pytest.raises(ValueError, match="Session cannot be serialized"):
# with roundtrip(ds, allow_distributed_write=False) as actual:
# pass


def test_distributed_workflow(tmpdir):
repo = Repository.create(
storage=local_filesystem_storage(tmpdir),
)
with distributed.Client(): # type: ignore [no-untyped-call]
ds = create_test_data().chunk(dim1=3, dim2=4)
session = repo.writable_session("main")
future = to_icechunk(ds, session.store, compute=False, mode="w")

session = repo.writable_session("main")
(new_session,) = dask.compute(future)
session.merge(new_session)
session.commit("foo")

roundtripped = xr.open_zarr(
repo.readonly_session(branch="main").store, consolidated=False
)
assert_identical(roundtripped, ds)


@pytest.mark.parametrize("compute", [False, True])
def test_threaded(compute: bool) -> None:
def test_threaded() -> None:
with dask.config.set(scheduler="threads"):
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds, compute=compute) as actual:
with roundtrip(ds) as actual:
assert_identical(actual, ds)
# with roundtrip(ds, allow_distributed_write=False) as actual:
# assert_identical(actual, ds)

0 comments on commit 4e9ad54

Please sign in to comment.