Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion esmf_regrid/experimental/unstructured_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,34 @@ def _map_complete_blocks(src, func, dims, out_sizes):
num_dims = len(dims)
num_out = len(out_sizes)
dropped_dims = []
new_axis = None
if num_out > num_dims:
# While this code should be robust for cases where num_out > num_dims > 1,
# there is some ambiguity as to what their behaviour ought to be.
# Since these cases are out of our own scope, we explicitly ignore them
# for the time being.
assert num_dims == 1
# While this code should be robust for cases where num_out > 2,
# we expect to handle at most 2D grids.
# Since these cases are out of our own scope, we explicitly ignore them
# for the time being.
assert num_out == 2
slice_index = sorted_dims[-1]
# Insert the remaining contents of out_sizes in the position immediately
# after the last dimension.
out_chunks[slice_index:slice_index] = out_sizes[num_dims:]
new_axis = range(slice_index, slice_index + num_out - num_dims)
elif num_dims > num_out:
# While this code should be robust for cases where num_dims > num_out > 1,
# there is some ambiguity as to what their behaviour ought to be.
# Since these cases are out of our own scope, we explicitly ignore them
# for the time being.
assert num_out == 1
# While this code should be robust for cases where num_dims > 2,
# we expect to handle at most 2D grids.
# Since these cases are out of our own scope, we explicitly ignore them
# for the time being.
assert num_dims == 2
dropped_dims = sorted_dims[num_out:]
# Remove the remaining dimensions from the expected output shape.
for dim in dropped_dims[::-1]:
Expand All @@ -87,7 +99,11 @@ def _map_complete_blocks(src, func, dims, out_sizes):
pass

return data.map_blocks(
func, chunks=out_chunks, drop_axis=dropped_dims, dtype=src.dtype
func,
chunks=out_chunks,
drop_axis=dropped_dims,
new_axis=new_axis,
dtype=src.dtype,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,20 @@ def test_laziness():
n_lons = 12
n_lats = 10
h = 4
i = 9
lon_bounds = (-180, 180)
lat_bounds = (-90, 90)

mesh = _gridlike_mesh(n_lons, n_lats)

src_data = np.arange(n_lats * n_lons * h).reshape([-1, h])
src_data = da.from_array(src_data, chunks=[15, 2])
# Add a chunked dimension both before and after the mesh dimension.
# The leading length 1 dimension matches the example in issue #135.
src_data = np.arange(i * n_lats * n_lons * h).reshape([1, i, -1, h])
src_data = da.from_array(src_data, chunks=[1, 3, 15, 2])
src = Cube(src_data)
mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face")
src.add_aux_coord(mesh_coord_x, 0)
src.add_aux_coord(mesh_coord_y, 0)
src.add_aux_coord(mesh_coord_x, 2)
src.add_aux_coord(mesh_coord_y, 2)
tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True)

rg = MeshToGridESMFRegridder(src, tgt)
Expand All @@ -160,6 +163,6 @@ def test_laziness():
result = rg(src)
assert result.has_lazy_data()
out_chunks = result.lazy_data().chunks
expected_chunks = ((10,), (12,), (2, 2))
expected_chunks = ((1,), (3, 3, 3), (10,), (12,), (2, 2))
assert out_chunks == expected_chunks
assert np.allclose(result.data.reshape([-1, h]), src_data)
assert np.allclose(result.data.reshape([1, i, -1, h]), src_data)