diff --git a/esmf_regrid/experimental/unstructured_scheme.py b/esmf_regrid/experimental/unstructured_scheme.py index e197f613..4a2bd6ab 100644 --- a/esmf_regrid/experimental/unstructured_scheme.py +++ b/esmf_regrid/experimental/unstructured_scheme.py @@ -63,22 +63,34 @@ def _map_complete_blocks(src, func, dims, out_sizes): num_dims = len(dims) num_out = len(out_sizes) dropped_dims = [] + new_axis = None if num_out > num_dims: # While this code should be robust for cases where num_out > num_dims > 1, # there is some ambiguity as to what their behaviour ought to be. # Since these cases are out of our own scope, we explicitly ignore them # for the time being. assert num_dims == 1 + # While this code should be robust for cases where num_out > 2, + # we expect to handle at most 2D grids. + # Since these cases are out of our own scope, we explicitly ignore them + # for the time being. + assert num_out == 2 slice_index = sorted_dims[-1] # Insert the remaining contents of out_sizes in the position immediately # after the last dimension. out_chunks[slice_index:slice_index] = out_sizes[num_dims:] + new_axis = range(slice_index, slice_index + num_out - num_dims) elif num_dims > num_out: # While this code should be robust for cases where num_dims > num_out > 1, # there is some ambiguity as to what their behaviour ought to be. # Since these cases are out of our own scope, we explicitly ignore them # for the time being. assert num_out == 1 + # While this code should be robust for cases where num_dims > 2, + # we expect to handle at most 2D grids. + # Since these cases are out of our own scope, we explicitly ignore them + # for the time being. + assert num_dims == 2 dropped_dims = sorted_dims[num_out:] # Remove the remaining dimensions from the expected output shape. for dim in dropped_dims[::-1]: @@ -87,7 +99,11 @@ def _map_complete_blocks(src, func, dims, out_sizes): pass return data.map_blocks( - func, chunks=out_chunks, drop_axis=dropped_dims, dtype=src.dtype + func, + chunks=out_chunks, + drop_axis=dropped_dims, + new_axis=new_axis, + dtype=src.dtype, ) diff --git a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py index 9546ac23..d35e79e2 100644 --- a/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py +++ b/esmf_regrid/tests/unit/experimental/unstructured_scheme/test_MeshToGridESMFRegridder.py @@ -141,17 +141,20 @@ def test_laziness(): n_lons = 12 n_lats = 10 h = 4 + i = 9 lon_bounds = (-180, 180) lat_bounds = (-90, 90) mesh = _gridlike_mesh(n_lons, n_lats) - src_data = np.arange(n_lats * n_lons * h).reshape([-1, h]) - src_data = da.from_array(src_data, chunks=[15, 2]) + # Add a chunked dimension both before and after the mesh dimension. + # The leading length 1 dimension matches the example in issue #135. + src_data = np.arange(i * n_lats * n_lons * h).reshape([1, i, -1, h]) + src_data = da.from_array(src_data, chunks=[1, 3, 15, 2]) src = Cube(src_data) mesh_coord_x, mesh_coord_y = mesh.to_MeshCoords("face") - src.add_aux_coord(mesh_coord_x, 0) - src.add_aux_coord(mesh_coord_y, 0) + src.add_aux_coord(mesh_coord_x, 2) + src.add_aux_coord(mesh_coord_y, 2) tgt = _grid_cube(n_lons, n_lats, lon_bounds, lat_bounds, circular=True) rg = MeshToGridESMFRegridder(src, tgt) @@ -160,6 +163,6 @@ def test_laziness(): result = rg(src) assert result.has_lazy_data() out_chunks = result.lazy_data().chunks - expected_chunks = ((10,), (12,), (2, 2)) + expected_chunks = ((1,), (3, 3, 3), (10,), (12,), (2, 2)) assert out_chunks == expected_chunks - assert np.allclose(result.data.reshape([-1, h]), src_data) + assert np.allclose(result.data.reshape([1, i, -1, h]), src_data)