Skip to content

Commit

Permalink
full compute support
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jan 27, 2025
1 parent e7f5eeb commit bd25331
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 14 deletions.
43 changes: 35 additions & 8 deletions icechunk-python/python/icechunk/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ def write_eager(self) -> None:

def write_lazy(
self,
*,
compute: bool = True,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
) -> None:
) -> Any:
"""
Write lazy arrays (e.g. dask) to store.
"""
Expand All @@ -166,6 +168,9 @@ def write_lazy(
if not self.writer.sources:
return

import dask

# TODO: this is all pretty dask specific at the moment
chunkmanager_store_kwargs = chunkmanager_store_kwargs or {}
chunkmanager_store_kwargs["load_stored"] = False
chunkmanager_store_kwargs["return_stored"] = True
Expand All @@ -183,10 +188,14 @@ def write_lazy(
chunk=extract_session,
aggregate=merge_sessions,
split_every=split_every,
compute=True,
compute=False,
**chunkmanager_store_kwargs,
)
self.store.session.merge(merged_session)
if compute:
(new_session,) = dask.compute(merged_session)
self.store.session.merge(new_session)
else:
return merged_session


def to_icechunk(
Expand All @@ -201,7 +210,8 @@ def to_icechunk(
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
) -> None:
compute: bool = True,
) -> Any:
"""
Write an Xarray object to a group of an icechunk store.
Expand Down Expand Up @@ -268,10 +278,26 @@ def to_icechunk(
`dask.array.store()`. Experimental API that should not be relied upon.
split_every: int, optional
Number of tasks to merge at every level of the tree reduction.
compute: bool
Whether to eagerly write chunked arrays.
Returns
-------
None
None if `compute` is `True`, `dask.Delayed` otherwise.
Examples
--------
    When `compute` is False, this function returns a `dask.Delayed` object. Call `dask.compute`
    on this object to execute the write and receive a Session object with the changes.
With distributed writes, you are responsible for merging these changes with your
`session` so that a commit can be properly executed
>>> session = repo.writable_session("main")
>>> future = to_icechunk(dataset, session.store, ..., compute=False)
>>> (new_session,) = dask.compute(future) # execute the write
>>> session.merge(new_session) # merge in the (possibly) remote changes with the local session
>>> session.commit("wrote some data!")
Notes
-----
Expand All @@ -281,8 +307,7 @@ def to_icechunk(
least one dimension in common with the region. Other variables
should be written in a separate single call to ``to_zarr()``.
- Dimensions cannot be included in both ``region`` and
``append_dim`` at the same time. To create empty arrays to fill
in with ``region``, use the `XarrayDatasetWriter` directly.
``append_dim`` at the same time.
"""

as_dataset = make_dataset(obj)
Expand All @@ -295,7 +320,9 @@ def to_icechunk(
# write in-memory arrays
writer.write_eager()
# eagerly write dask arrays
writer.write_lazy(chunkmanager_store_kwargs=chunkmanager_store_kwargs)
return writer.write_lazy(
compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs
)


@overload
Expand Down
33 changes: 29 additions & 4 deletions icechunk-python/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,50 @@

import dask
import distributed
import xarray as xr
from icechunk import Repository, local_filesystem_storage
from icechunk.xarray import to_icechunk
from tests.test_xarray import create_test_data, roundtrip
from xarray.testing import assert_identical


def test_distributed() -> None:
@pytest.mark.parametrize("compute", [False, True])
def test_distributed(compute: bool) -> None:
with distributed.Client(): # type: ignore [no-untyped-call]
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds) as actual:
with roundtrip(ds, compute=compute) as actual:
assert_identical(actual, ds)

# with pytest.raises(ValueError, match="Session cannot be serialized"):
# with roundtrip(ds, allow_distributed_write=False) as actual:
# pass


def test_threaded() -> None:
def test_distributed_workflow(tmpdir) -> None:
    """End-to-end lazy-write workflow against a distributed scheduler.

    Exercises the ``compute=False`` path: ``to_icechunk`` returns a delayed
    object, which is computed later under a *fresh* writable session; the
    resulting remote changes are merged back before committing.
    """
    # Repository backed by a temporary local filesystem store.
    repo = Repository.create(
        storage=local_filesystem_storage(tmpdir),
    )
    with distributed.Client():  # type: ignore [no-untyped-call]
        ds = create_test_data().chunk(dim1=3, dim2=4)
        session = repo.writable_session("main")
        # compute=False defers the chunked-array write; `future` is a delayed
        # object that yields a Session describing the remote changes.
        future = to_icechunk(ds, session.store, compute=False, mode="w")

    # NOTE(review): a second writable session is opened here and the first is
    # discarded — presumably to show the merge works across sessions; confirm
    # this is the intended demonstration rather than an oversight.
    session = repo.writable_session("main")
    (new_session,) = dask.compute(future)
    # Fold the changes produced by the distributed write into the local
    # session so the commit below captures them.
    session.merge(new_session)
    session.commit("foo")

    # Read back through a read-only session and verify a lossless roundtrip.
    roundtripped = xr.open_zarr(
        repo.readonly_session(branch="main").store, consolidated=False
    )
    assert_identical(roundtripped, ds)


@pytest.mark.parametrize("compute", [False, True])
def test_threaded(compute: bool) -> None:
with dask.config.set(scheduler="threads"):
ds = create_test_data().chunk(dim1=3, dim2=4)
with roundtrip(ds) as actual:
with roundtrip(ds, compute=compute) as actual:
assert_identical(actual, ds)
# with roundtrip(ds, allow_distributed_write=False) as actual:
# assert_identical(actual, ds)
12 changes: 10 additions & 2 deletions icechunk-python/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,20 @@ def create_test_data(


@contextlib.contextmanager
def roundtrip(data: xr.Dataset) -> Generator[xr.Dataset, None, None]:
def roundtrip(
data: xr.Dataset, compute: bool = True
) -> Generator[xr.Dataset, None, None]:
"""Since this roundtrips and returns a Dataset‚the compute kwarg only controls what is passed to to_icechunk.
If False, we eagerly compute here prior to loading data"""
with tempfile.TemporaryDirectory() as tmpdir:
repo = Repository.create(local_filesystem_storage(tmpdir))
session = repo.writable_session("main")
to_icechunk(data, store=session.store, mode="w")
ret = to_icechunk(data, store=session.store, mode="w", compute=compute)
if not compute:
import dask

(new_session,) = dask.compute(ret)
session.merge(new_session)
# if allow_distributed_write:
# with session.allow_distributed_write():
# to_icechunk(data, store=session.store, mode="w")
Expand Down

0 comments on commit bd25331

Please sign in to comment.