Skip to content

Commit

Permalink
Update map_blocks to use chunksizes property. (#6776)
Browse files Browse the repository at this point in the history
* Update map_blocks to use chunksizes property.

Raise nicer error if provided template has no dask arrays.

Closes #6763

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix typing

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
dcherian and pre-commit-ci[bot] authored Jul 14, 2022
1 parent f28d7f8 commit 5678b75
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
12 changes: 6 additions & 6 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,19 +373,19 @@ def _wrapper(
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template._indexes[k] for k in new_indexes})
output_chunks = {
output_chunks: Mapping[Hashable, tuple[int, ...]] = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}

else:
# template xarray object has been provided with proper sizes and chunk shapes
indexes = dict(template._indexes)
if isinstance(template, DataArray):
output_chunks = dict(
zip(template.dims, template.chunks) # type: ignore[arg-type]
output_chunks = template.chunksizes
if not output_chunks:
raise ValueError(
"Provided template has no dask arrays. "
" Please construct a template with appropriately chunked dask arrays."
)
else:
output_chunks = dict(template.chunks)

for dim in output_chunks:
if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
Expand Down
9 changes: 9 additions & 0 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,6 +1243,15 @@ def sumda(da1, da2):
)
xr.testing.assert_equal((da1 + da2).sum("x"), mapped)

# bad template: not chunked
with pytest.raises(ValueError, match="Provided template has no dask arrays"):
xr.map_blocks(
lambda a, b: (a + b).sum("x"),
da1,
args=[da2],
template=da1.sum("x").compute(),
)


@pytest.mark.parametrize("obj", [make_da(), make_ds()])
def test_map_blocks_add_attrs(obj):
Expand Down

0 comments on commit 5678b75

Please sign in to comment.