diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 5f325a554..0aaeed722 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -168,8 +168,6 @@ def write_lazy( if not self.writer.sources: return - import dask - # TODO: this is all pretty dask specific at the moment chunkmanager_store_kwargs = chunkmanager_store_kwargs or {} chunkmanager_store_kwargs["load_stored"] = False @@ -181,6 +179,10 @@ def write_lazy( compute=False, chunkmanager_store_kwargs=chunkmanager_store_kwargs ) # type: ignore[no-untyped-call] + if not compute: + # do not execute the write in this case + return + # Now we tree-reduce all changesets merged_session = stateful_store_reduce( stored_arrays, @@ -188,14 +190,10 @@ def write_lazy( chunk=extract_session, aggregate=merge_sessions, split_every=split_every, - compute=False, + compute=True, **chunkmanager_store_kwargs, ) - if compute: - (new_session,) = dask.compute(merged_session) - self.store.session.merge(new_session) - else: - return merged_session + self.store.session.merge(merged_session) def to_icechunk( @@ -211,7 +209,7 @@ def to_icechunk( chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, split_every: int | None = None, compute: bool = True, -) -> Any: +) -> None: """ Write an Xarray object to a group of an icechunk store. @@ -283,21 +281,7 @@ def to_icechunk( Returns ------- - None if `compute` is `True`, `dask.Delayed` otherwise. - - Examples - -------- - - When `compute` is False, this functions returns a `dask.Delayed` object. Call `dask.compute` - on this object to execute the write and receive a Session object with the changes. - With distributed writes, you are responsible for merging these changes with your - `session` so that a commit can be properly executed - - >>> session = repo.writable_session("main") - >>> future = to_icechunk(dataset, session.store, ..., compute=False) - >>> (new_session,) = dask.compute(future) # execute the write - >>> session.merge(new_session) # merge in the (possibly) remote changes with the local session - >>> session.commit("wrote some data!") + None Notes ----- @@ -308,6 +292,9 @@ def to_icechunk( should be written in a separate single call to ``to_zarr()``. - Dimensions cannot be included in both ``region`` and ``append_dim`` at the same time. + + Unlike Xarray's `to_zarr`, this function does *not* return a dask Delayed object + if `compute=False`. """ as_dataset = make_dataset(obj) @@ -320,7 +307,7 @@ def to_icechunk( # write in-memory arrays writer.write_eager() # eagerly write dask arrays - return writer.write_lazy( + writer.write_lazy( compute=compute, chunkmanager_store_kwargs=chunkmanager_store_kwargs ) diff --git a/icechunk-python/tests/test_dask.py b/icechunk-python/tests/test_dask.py index 19e86a840..031e32f22 100644 --- a/icechunk-python/tests/test_dask.py +++ b/icechunk-python/tests/test_dask.py @@ -5,18 +5,14 @@ import dask import distributed -import xarray as xr -from icechunk import Repository, local_filesystem_storage -from icechunk.xarray import to_icechunk from tests.test_xarray import create_test_data, roundtrip from xarray.testing import assert_identical -@pytest.mark.parametrize("compute", [False, True]) -def test_distributed(compute: bool) -> None: +def test_distributed() -> None: with distributed.Client(): # type: ignore [no-untyped-call] ds = create_test_data().chunk(dim1=3, dim2=4) - with roundtrip(ds, compute=compute) as actual: + with roundtrip(ds) as actual: assert_identical(actual, ds) # with pytest.raises(ValueError, match="Session cannot be serialized"): @@ -24,31 +20,10 @@ def test_distributed(compute: bool) -> None: # pass -def test_distributed_workflow(tmpdir): - repo = Repository.create( - storage=local_filesystem_storage(tmpdir), - ) - with distributed.Client(): # type: ignore [no-untyped-call] - ds = create_test_data().chunk(dim1=3, dim2=4) - session = repo.writable_session("main") - future = to_icechunk(ds, session.store, compute=False, mode="w") - - session = repo.writable_session("main") - (new_session,) = dask.compute(future) - session.merge(new_session) - session.commit("foo") - - roundtripped = xr.open_zarr( - repo.readonly_session(branch="main").store, consolidated=False - ) - assert_identical(roundtripped, ds) - - -@pytest.mark.parametrize("compute", [False, True]) -def test_threaded(compute: bool) -> None: +def test_threaded() -> None: with dask.config.set(scheduler="threads"): ds = create_test_data().chunk(dim1=3, dim2=4) - with roundtrip(ds, compute=compute) as actual: + with roundtrip(ds) as actual: assert_identical(actual, ds) # with roundtrip(ds, allow_distributed_write=False) as actual: # assert_identical(actual, ds)